# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
# Author(s): Vincent Rouvreau
#
# Copyright (C) 2024 Inria
#
# Modification(s):
# - YYYY/MM Author: Description of the modification
__license__ = "GPL v3"
from typing import Iterable, Literal, Optional, Any
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from joblib import Parallel, delayed
from .. import DelaunayCechComplex, AlphaComplex
# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
# sequenceDiagram
# participant USER
# participant R as CechPersistence
# USER->>R: fit_transform(X)
# Note right of R: homology_dimensions=[i,j]
# R->>thread1: _transform(X[0])
# R->>thread2: _transform(X[1])
# Note right of R: ...
# thread1->>R: [array( Hi(X[0]) ), array( Hj(X[0]) )]
# thread2->>R: [array( Hi(X[1]) ), array( Hj(X[1]) )]
# Note right of R: ...
# R->>USER: [[array( Hi(X[0]) ), array( Hj(X[0]) )],<br/> [array( Hi(X[1]) ), array( Hj(X[1]) )],<br/>...]
class CechPersistence(BaseEstimator, TransformerMixin):
    """
    This is a class for computing the same persistent homology as the Čech complex, while being significantly smaller,
    using internally a :class:`~gudhi.DelaunayCechComplex`.
    """

    def __init__(
        self,
        homology_dimensions: int | Iterable[int],
        precision: Literal["fast", "safe", "exact"] = "safe",
        output_squared_values: bool = False,
        threshold: float = float("inf"),
        homology_coeff_field: int = 11,
        n_jobs: Optional[int] = None,
    ):
        """Constructor for the CechPersistence class.

        Parameters:
            homology_dimensions: The returned persistence diagrams dimension(s).
                Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only
                one dimension matters (in other words, when `homology_dimensions` is an int).
            precision: Complex precision can be 'fast', 'safe' or 'exact'. Default is 'safe'.
            output_squared_values: Square filtration values when `True`. Default is `False` (contrary to its
                default value in :class:`~gudhi.DelaunayCechComplex`).
            threshold: The maximum filtration value the simplices shall not exceed. Default is set to infinity, and
                there is very little point using anything else since it does not save time.
                Notice that the filtration values (equal to radii, by default, or squared radii) will be different in
                function of `output_squared_values` value (when `False`, by default, or `True`).
            homology_coeff_field: The homology coefficient field. Must be a prime number. Default value is 11.
            n_jobs: Number of jobs to run in parallel. `None` (default value) means `n_jobs = 1` unless in a
                joblib.parallel_backend context. `-1` means using all processors. cf.
                https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html for more details.

        Raises:
            ValueError: If `precision` is not one of 'fast', 'safe' or 'exact', or if `homology_dimensions` is
                neither an int nor an iterable of ints.
        """
        # sklearn convention: __init__ only stores the hyper-parameters, unmodified
        self.homology_dimensions = homology_dimensions
        self.precision = precision
        self.output_squared_values = output_squared_values
        self.threshold = threshold
        self.homology_coeff_field = homology_coeff_field
        self.n_jobs = n_jobs

        if self.precision not in ["safe", "fast", "exact"]:
            raise ValueError("Unknown precision")
        # Done twice (in __init__ and fit), but exception is better the sooner
        dim_list = np.asarray(self.homology_dimensions, dtype=int)
        if dim_list.ndim not in [0, 1]:
            raise ValueError(f"Invalid dimension. Got {self.homology_dimensions=}, expected type=int|Iterable[int].")

    def fit(self, X, Y=None):
        """
        Fit the `CechPersistence` class in function of `homology_dimensions` type.
        """
        # Must be in the `fit` part, as `transform` should be const and as `__init__` is not called on a parallel grid
        # search for instance
        self._dim_list = np.asarray(self.homology_dimensions, dtype=int)
        # When homology_dimensions is a plain int (0-d array), remember to unwrap the per-diagram list in transform
        self._unwrap = False
        if self._dim_list.ndim == 0:
            self._unwrap = True
            self._dim_list = self._dim_list.reshape(1)
        return self

    def __transform(self, inputs):
        """Compute the persistence diagram(s) of a single point cloud, one per requested homology dimension."""
        max_dimension = np.max(self._dim_list) + 1
        # alpha complex points take floats
        pts = np.asarray(inputs, dtype=np.float64)
        delaunay_cech = DelaunayCechComplex(points=pts, precision=self.precision)
        stree = delaunay_cech.create_simplex_tree(
            max_alpha_square=self.threshold, output_squared_values=self.output_squared_values
        )
        persistence_dim_max = False
        # Specific case where, despite expansion(max_dimension), stree has a lower dimension
        if max_dimension > stree.dimension():
            persistence_dim_max = True
        stree.compute_persistence(
            homology_coeff_field=self.homology_coeff_field, persistence_dim_max=persistence_dim_max
        )
        return [stree.persistence_intervals_in_dimension(dim) for dim in self._dim_list]

    def transform(self, X, Y=None):
        """Compute all the Delaunay-Čech complexes and their associated persistence diagrams.

        :param X: list of point clouds as Euclidean coordinates.
        :type X: list of list of float OR list of numpy.ndarray

        :return: Persistence diagrams in the format:

            - If `homology_dimensions` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
            - If `homology_dimensions` was set to `[i, j]`:
              `[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]`
        :rtype: list of (,2) float numpy.ndarray
        """
        # One __transform per point cloud, in parallel (cf. the sequence diagram at the top of this file).
        # NOTE(review): threads preferred here — presumably the complex construction releases the GIL; confirm.
        res = Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(inputs) for inputs in X)
        if self._unwrap:
            # homology_dimensions was an int: return one diagram per input instead of a one-element list
            res = [d[0] for d in res]
        return res

    def get_feature_names_out(self):
        """Provide column names for implementing sklearn's set_output API."""
        return [f"H{i}" for i in self._dim_list]
class WeightedCechPersistence(BaseEstimator, TransformerMixin):
    """
    This is a class for computing the same persistent homology as the Weighted Čech complex, while being significantly
    smaller, using internally a Weighted version of :class:`~gudhi.AlphaComplex`.
    """

    def __init__(
        self,
        homology_dimensions: int | Iterable[int],
        precision: Literal["fast", "safe", "exact"] = "safe",
        output_squared_values: bool = True,
        threshold: float = float("inf"),
        homology_coeff_field: int = 11,
        n_jobs: Optional[int] = None,
    ):
        """Constructor for the WeightedCechPersistence class.

        Parameters:
            homology_dimensions: The returned persistence diagrams dimension(s).
                Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only
                one dimension matters (in other words, when `homology_dimensions` is an int).
            precision: Complex precision can be 'fast', 'safe' or 'exact'. Default is 'safe'.
            output_squared_values: Square filtration values when `True`. Default is `True` (contrary to the unweighted
                version :class:`~gudhi.sklearn.CechPersistence`).
            threshold: The maximum filtration value the simplices shall not exceed. Default is set to infinity, and
                there is very little point using anything else since it does not save time.
                Notice that the filtration values (equal to squared radii, by default, or radii) will be different in
                function of `output_squared_values` value (when `True`, by default, or `False`).
            homology_coeff_field: The homology coefficient field. Must be a prime number. Default value is 11.
            n_jobs: Number of jobs to run in parallel. `None` (default value) means `n_jobs = 1` unless in a
                joblib.parallel_backend context. `-1` means using all processors. cf.
                https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html for more details.

        Raises:
            ValueError: If `precision` is not one of 'fast', 'safe' or 'exact', or if `homology_dimensions` is
                neither an int nor an iterable of ints.
        """
        # sklearn convention: __init__ only stores the hyper-parameters, unmodified
        self.homology_dimensions = homology_dimensions
        self.precision = precision
        self.output_squared_values = output_squared_values
        self.threshold = threshold
        self.homology_coeff_field = homology_coeff_field
        self.n_jobs = n_jobs

        if self.precision not in ["safe", "fast", "exact"]:
            raise ValueError("Unknown precision")
        # Done twice (in __init__ and fit), but exception is better the sooner
        dim_list = np.asarray(self.homology_dimensions, dtype=int)
        if dim_list.ndim not in [0, 1]:
            raise ValueError(f"Invalid dimension. Got {self.homology_dimensions=}, expected type=int|Iterable[int].")

    def fit(self, X, Y=None):
        """
        Fit the `WeightedCechPersistence` class in function of `homology_dimensions` type.
        """
        # Must be in the `fit` part, as `transform` should be const and as `__init__` is not called on a parallel grid
        # search for instance
        self._dim_list = np.asarray(self.homology_dimensions, dtype=int)
        # When homology_dimensions is a plain int (0-d array), remember to unwrap the per-diagram list in transform
        self._unwrap = False
        if self._dim_list.ndim == 0:
            self._unwrap = True
            self._dim_list = self._dim_list.reshape(1)
        return self

    def __transform(self, inputs):
        """Compute the persistence diagram(s) of a single weighted point cloud.

        Each row of `inputs` is a point's Euclidean coordinates followed by its weight as last coordinate.
        """
        max_dimension = np.max(self._dim_list) + 1
        # alpha complex weights and points take floats
        wgts = np.asarray(inputs[:, -1], dtype=np.float64)
        pts = np.asarray(inputs[:, :-1], dtype=np.float64)
        weighted_alpha = AlphaComplex(points=pts, weights=wgts, precision=self.precision)
        stree = weighted_alpha.create_simplex_tree(
            max_alpha_square=self.threshold, output_squared_values=self.output_squared_values
        )
        persistence_dim_max = False
        # Specific case where, despite expansion(max_dimension), stree has a lower dimension
        if max_dimension > stree.dimension():
            persistence_dim_max = True
        stree.compute_persistence(
            homology_coeff_field=self.homology_coeff_field, persistence_dim_max=persistence_dim_max
        )
        return [stree.persistence_intervals_in_dimension(dim) for dim in self._dim_list]

    def transform(self, X, Y=None):
        """Compute all the weighted Delaunay-Čech complexes and their associated persistence diagrams.

        :param X: list of weighted point clouds, where each point is given by its Euclidean coordinates followed by
            its weight as last coordinate.
        :type X: list of list of float OR list of numpy.ndarray

        :return: Persistence diagrams in the format:

            - If `homology_dimensions` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
            - If `homology_dimensions` was set to `[i, j]`:
              `[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]`
        :rtype: list of (,2) float numpy.ndarray
        """
        # One __transform per weighted point cloud, in parallel (cf. the sequence diagram at the top of this file).
        # NOTE(review): threads preferred here — presumably the complex construction releases the GIL; confirm.
        res = Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(inputs) for inputs in X)
        if self._unwrap:
            # homology_dimensions was an int: return one diagram per input instead of a one-element list
            res = [d[0] for d in res]
        return res

    def get_feature_names_out(self):
        """Provide column names for implementing sklearn's set_output API."""
        return [f"H{i}" for i in self._dim_list]