Source code for gudhi.tensorflow.lower_star_simplex_tree_layer

import numpy               as np
import tensorflow          as tf

#########################################
# Lower star filtration on simplex tree #
#########################################

# The parameters of the model are the vertex function values of the simplex tree.

def _LowerStarSimplexTree(simplextree, filtration, dimensions, homology_coeff_field, persistence_dim_max):
    """Return, per homology dimension, the vertex indices of the lower-star persistence generators.

    Side effect: the filtration stored in ``simplextree`` is overwritten. It is
    first reset with ``reset_filtration(-inf, 0)``. Then vertex ``i`` is assigned
    ``filtration[i]`` for every ``i`` in ``range(simplextree.num_vertices())``.
    Finally ``make_filtration_non_decreasing`` is applied before persistence is
    computed.

    Parameters:
        simplextree: simplex tree on which to compute persistence. Its vertices
            are assumed to be named ``0 .. num_vertices()-1``.
        filtration: indexable of function values on the vertices (entry ``i``
            belongs to vertex ``i``).
        dimensions: iterable of homology dimensions to extract.
        homology_coeff_field: field coefficient forwarded to ``compute_persistence``.
        persistence_dim_max: flag forwarded to ``compute_persistence``.

    Returns:
        A list with one ``(finite_indices, essential_indices)`` tuple per entry of
        ``dimensions``. Both members are flattened ``np.int32`` arrays of vertex
        indices. A dimension beyond those reported by the simplex tree yields
        two empty arrays.
    """
    simplextree.reset_filtration(-np.inf, 0)

    # Vertex i takes the i-th function value; higher simplices are then fixed up.
    for vertex in range(simplextree.num_vertices()):
        simplextree.assign_filtration([vertex], filtration[vertex])
    simplextree.make_filtration_non_decreasing()

    simplextree.compute_persistence(
        homology_coeff_field=homology_coeff_field,
        persistence_dim_max=persistence_dim_max,
    )

    # Generators are grouped by dimension: [0] holds finite pairs, [1] essential ones.
    generators = simplextree.lower_star_persistence_generators()
    finite_by_dim, essential_by_dim = generators[0], generators[1]

    def _as_index_array(per_dim, dim, empty_width):
        # Dimensions past the last reported one get an empty placeholder array.
        if len(per_dim) >= dim + 1:
            group = per_dim[dim]
        else:
            group = np.empty(shape=[0, empty_width])
        return np.array(group.flatten(), dtype=np.int32)

    return [
        (
            _as_index_array(finite_by_dim, dim, 2),
            _as_index_array(essential_by_dim, dim, 1),
        )
        for dim in dimensions
    ]

class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
    """
    TensorFlow layer for computing lower-star persistence out of a simplex tree
    """

    def __init__(self, simplextree, homology_dimensions, min_persistence=None, homology_coeff_field=11, persistence_dim_max=False, **kwargs):
        """
        Constructor for the LowerStarSimplexTreeLayer class

        Parameters:
            simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n-1, where n is its number of vertices. Note that its filtration values are modified in each call of the class.
            homology_dimensions (List[int]): list of homology dimensions
            min_persistence (List[float]): one threshold per entry of homology_dimensions. In the output, a finite point (b, d) is kept only if |d - b| is strictly greater than the threshold. A negative threshold disables the filtering for that dimension. Essential points are never filtered. Default is None, in which case 0. is used for all dimensions.
            homology_coeff_field (int): homology field coefficient. Must be a prime number. Default value is 11. Max is 46337.
            persistence_dim_max (bool): if true, the persistent homology for the maximal dimension in the simplex tree is computed. If false, it is ignored. Default is false.

        Raises:
            ValueError: if min_persistence and homology_dimensions do not have the same length.
        """
        # dynamic=True: call() reads concrete values with .numpy(), which presumably needs eager execution.
        super().__init__(dynamic=True, **kwargs)
        self.dimensions = homology_dimensions
        self.simplextree = simplextree
        self.min_persistence = min_persistence if min_persistence is not None else [0. for _ in range(len(self.dimensions))]
        self.hcf = homology_coeff_field
        self.pdm = persistence_dim_max
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(self.min_persistence) != len(self.dimensions):
            raise ValueError(
                "min_persistence must have one entry per homology dimension "
                f"(got {len(self.min_persistence)} thresholds for {len(self.dimensions)} dimensions)"
            )

    def call(self, filtration):
        """
        Compute lower-star persistence diagram associated to a function defined on the vertices of the simplex tree

        Parameters:
            filtration (TensorFlow variable): filter function values over the vertices of the simplex tree. The ith entry of filtration corresponds to vertex i in self.simplextree

        Returns:
            List[Tuple[tf.Tensor,tf.Tensor]]: List of lower-star persistence diagrams. The length of this list is the same as that of dimensions, i.e., there is one persistence diagram per homology dimension provided in the input list dimensions. Moreover, the finite and essential parts of the persistence diagrams are provided separately: each element of this list is a tuple of size two that contains the finite and essential parts of the corresponding persistence diagram, of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively. The same list is also stored in self.dgms.
        """
        # Don't try to compute gradients for the vertex pairs: they are computed from
        # detached numpy values, and gradients reach `filtration` only through tf.gather below.
        indices = _LowerStarSimplexTree(self.simplextree, filtration.numpy(), self.dimensions, self.hcf, self.pdm)

        # Get persistence diagrams (one entry of `indices` per requested dimension)
        self.dgms = []
        for idx_dim, (finite_indices, essential_indices) in enumerate(indices):
            finite_dgm = tf.reshape(tf.gather(filtration, finite_indices), [-1, 2])
            essential_dgm = tf.reshape(tf.gather(filtration, essential_indices), [-1, 1])
            min_pers = self.min_persistence[idx_dim]
            if min_pers >= 0:
                # Drop finite points whose |death - birth| does not exceed the threshold.
                keep = tf.math.abs(finite_dgm[:, 1] - finite_dgm[:, 0]) > min_pers
                finite_dgm = tf.boolean_mask(finite_dgm, keep)
            self.dgms.append((finite_dgm, essential_dgm))
        return self.dgms