import numpy as np
import tensorflow as tf
#########################################
# Lower star filtration on simplex tree #
#########################################
# The parameters of the model are the vertex function values of the simplex tree.
def _LowerStarSimplexTree(simplextree, filtration, dimensions, homology_coeff_field, persistence_dim_max):
    """
    Compute lower-star persistence on a simplex tree and return, for each
    requested homology dimension, the vertex indices generating the pairs.

    Side effect: the filtration values stored in ``simplextree`` are overwritten.

    Parameters:
        simplextree (gudhi.SimplexTree): simplex tree on which to compute persistence.
            Its vertices are assumed to be named 0, ..., n-1, because vertex ``i``
            receives ``filtration[i]``.
        filtration (indexable of float): function values on the vertices of ``simplextree``.
        dimensions (iterable of int): homology dimensions to extract.
        homology_coeff_field (int): homology field coefficient, forwarded to
            ``simplextree.compute_persistence``.
        persistence_dim_max (bool): forwarded to ``simplextree.compute_persistence``.

    Returns:
        list of (np.ndarray, np.ndarray): one tuple per entry of ``dimensions``.
        The first element is the finite vertex pairs flattened row by row, and
        the second the essential vertex indices flattened. Both are int32 1-D
        arrays. A dimension absent from the generator lists yields empty arrays.
    """
    # Clear the previous filtration (value -inf from dimension 0 upwards) so
    # that only the vertex values assigned below drive the result.
    simplextree.reset_filtration(-np.inf, 0)
    for vertex in range(simplextree.num_vertices()):
        simplextree.assign_filtration([vertex], filtration[vertex])
    # Raise every simplex to at least the values of its faces: starting from
    # -inf, higher simplices pick up their vertex values (lower-star filtration).
    simplextree.make_filtration_non_decreasing()

    simplextree.compute_persistence(
        homology_coeff_field=homology_coeff_field,
        persistence_dim_max=persistence_dim_max,
    )

    # Index 0: finite generators per dimension; index 1: essential generators per dimension.
    generators = simplextree.lower_star_persistence_generators()
    finite_by_dim = generators[0]
    essential_by_dim = generators[1]

    def _flat_indices(per_dim, dim, width):
        # Select the generators of ``dim`` (an empty (0, width) array when the
        # list stops before ``dim``) and flatten them to int32 vertex indices.
        if dim < len(per_dim):
            chosen = per_dim[dim]
        else:
            chosen = np.empty(shape=[0, width])
        return np.array(chosen.flatten(), dtype=np.int32)

    return [
        (_flat_indices(finite_by_dim, dim, 2), _flat_indices(essential_by_dim, dim, 1))
        for dim in dimensions
    ]
class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
    """
    TensorFlow layer for computing lower-star persistence out of a simplex tree
    """

    def __init__(self, simplextree, homology_dimensions, min_persistence=None, homology_coeff_field=11, persistence_dim_max=False, **kwargs):
        """
        Constructor for the LowerStarSimplexTreeLayer class

        Parameters:
            simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n-1, where n is its number of vertices. Note that its filtration values are modified in each call of the class.
            homology_dimensions (List[int]): list of homology dimensions
            min_persistence (List[float]): one threshold per homology dimension. Finite points whose |death - birth| is not strictly greater than the threshold are removed from the output diagrams; a negative threshold disables this filtering. Default None, in which case 0. is used for all dimensions.
            homology_coeff_field (int): homology field coefficient. Must be a prime number. Default value is 11. Max is 46337.
            persistence_dim_max (bool): if true, the persistent homology for the maximal dimension in the simplex tree is computed. If false, it is ignored. Default is false.
        """
        super().__init__(**kwargs)
        self.dimensions = homology_dimensions
        self.simplextree = simplextree
        # Default to a threshold of 0. for every requested homology dimension.
        self.min_persistence = min_persistence if min_persistence is not None else [0.] * len(self.dimensions)
        self.hcf = homology_coeff_field
        self.pdm = persistence_dim_max
        assert len(self.min_persistence) == len(self.dimensions), \
            "min_persistence must contain one value per homology dimension"

    def call(self, filtration):
        """
        Compute lower-star persistence diagram associated to a function defined on the vertices of the simplex tree

        Parameters:
            filtration (TensorFlow variable): filter function values over the vertices of the simplex tree. The ith entry of filtration corresponds to vertex i in self.simplextree. It must support ``.numpy()`` (eager execution).

        Returns:
            List[Tuple[tf.Tensor,tf.Tensor]]: List of lower-star persistence diagrams. The length of this list is the same than that of dimensions, i.e., there is one persistence diagram per homology dimension provided in the input list dimensions. Moreover, the finite and essential parts of the persistence diagrams are provided separately: each element of this list is a tuple of size two that contains the finite and essential parts of the corresponding persistence diagram, of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively. The same list is also stored in ``self.dgms``.
        """
        # The vertex pairing is computed on a NumPy copy, so no gradient flows
        # through it; gradients reach `filtration` only through tf.gather below.
        indices = _LowerStarSimplexTree(self.simplextree, filtration.numpy(), self.dimensions, self.hcf, self.pdm)

        # Get persistence diagrams
        self.dgms = []
        for idx_dim, (finite_idx, essential_idx) in enumerate(indices):
            finite_dgm = tf.reshape(tf.gather(filtration, finite_idx), [-1, 2])
            essential_dgm = tf.reshape(tf.gather(filtration, essential_idx), [-1, 1])
            min_pers = self.min_persistence[idx_dim]
            if min_pers >= 0:
                # Keep only finite points with |death - birth| > min_pers.
                # tf.where yields [k, 1] coordinates, hence the final reshape.
                persistent_indices = tf.where(tf.math.abs(finite_dgm[:, 1] - finite_dgm[:, 0]) > min_pers)
                finite_dgm = tf.reshape(tf.gather(finite_dgm, indices=persistent_indices), [-1, 2])
            self.dgms.append((finite_dgm, essential_dgm))
        return self.dgms