# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
# Author(s): Mathieu Carrière, Vincent Rouvreau
#
# Copyright (C) 2018-2019 Inria
#
# Modification(s):
# - 2021/10 Vincent Rouvreau: Add DimensionSelector
# - YYYY/MM Author: Description of the modification
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.preprocessing import StandardScaler
#############################################
# Utils #####################################
#############################################
def _maybe_fit_transform(obj, attr, diag):
    """
    Apply an estimator to a single persistence diagram, as done by the __call__ methods of this module.

    If `obj` has been fitted (detected by the presence of the fitted attribute `attr`), its transform() method is
    used directly. Otherwise fit_transform() is called on an unfitted clone of `obj` built with
    :func:`sklearn.base.clone`. That clone also contains fresh copies of any estimators passed to `obj` as parameters,
    so neither `obj` nor those estimators are modified, and later calls are not affected.

    Parameters:
        obj: estimator with get_params(), transform() and fit_transform() methods.
        attr (str): name of an attribute that only exists once `obj` has been fitted.
        diag (n x 2 numpy array): input persistence diagram.

    Returns:
        the first element of the result of transforming the one-element list `[diag]`.
    """
    if hasattr(obj, attr):
        result = obj.transform([diag])
    else:
        # Re-instantiating with obj.get_params() would share nested estimators (e.g. the scalers held by
        # DiagramScaler) with obj, so fitting the copy would refit them in place. clone() copies them as fresh,
        # unfitted estimators and uses only the shallow (deep=False) constructor parameters.
        result = clone(obj).fit_transform([diag])
    return result[0]
class Clamping(BaseEstimator, TransformerMixin):
    """
    This is a class for clamping a list of values. It is not meant to be called directly on (a list of) persistence diagrams, but it is rather meant to be used as a parameter for the DiagramScaler class. As such it has the same methods and purpose as common scalers from sklearn.preprocessing such as MinMaxScaler, RobustScaler, StandardScaler, etc. A typical use would be for instance if you want to clamp abscissae or ordinates (or both) of persistence diagrams within a pre-defined interval.
    """

    def __init__(self, minimum=-np.inf, maximum=np.inf):
        """
        Constructor for the Clamping class.

        Parameters:
            minimum (float): lower clamping value (default -numpy.inf). Values below it are replaced by it.
            maximum (float): upper clamping value (default numpy.inf). Values above it are replaced by it.
        """
        self.minimum = minimum
        self.maximum = maximum

    def fit(self, X, y=None):
        """
        Fit the Clamping class on a list of values (this function actually does nothing but is useful when Clamping is included in a scikit-learn Pipeline).

        Parameters:
            X (numpy array of size n): input values.
            y (n x 1 array): value labels (unused).

        Returns:
            self
        """
        return self

    def transform(self, X):
        """
        Clamp a list of values within the interval [**minimum**, **maximum**].

        Parameters:
            X (numpy array of size n): input values.

        Returns:
            numpy array with the same shape as X: clamped values.
        """
        # Needed because TransformerMixin.fit_transform() and DiagramScaler call transform() on their scalers.
        return np.clip(X, self.minimum, self.maximum)
#############################################
# Preprocessing #############################
#############################################
class DiagramScaler(BaseEstimator, TransformerMixin):
    """
    This is a class for preprocessing persistence diagrams with a given list of scalers, such as those included in scikit-learn.
    """

    def __init__(self, use=False, scalers=[]):
        """
        Constructor for the DiagramScaler class.

        Parameters:
            use (bool): whether to use the class or not (default False).
            scalers (list of tuples): list of scalers to be fit on the persistence diagrams (default []). Each element of the list is a tuple with two elements: the first one is a list of coordinates, and the second one is a scaler (i.e. an object with fit() and transform() methods) that is going to be applied to these coordinates. Common scalers can be found in the scikit-learn library (such as MinMaxScaler for instance).
        """
        # Parameters are stored unchanged, as required by scikit-learn's get_params()/clone() conventions.
        self.scalers = scalers
        self.use = use

    def fit(self, X, y=None):
        """
        Fit the DiagramScaler class on a list of persistence diagrams: persistence diagrams are concatenated in a big numpy array, and scalers are fit (by calling their fit() method) on their corresponding coordinates in this big array.

        Parameters:
            X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
            y (n x 1 array): persistence diagram labels (unused).

        Returns:
            self

        Raises:
            ValueError: if **use** is True and X is empty (numpy cannot concatenate an empty list).
        """
        if self.use:
            # np.concatenate also handles a single diagram, so no special case is needed.
            P = np.concatenate(X, 0)
            for (indices, scaler) in self.scalers:
                # All selected coordinates are pooled into one column, so a single scaler is fit on them jointly.
                scaler.fit(np.reshape(P[:, indices], [-1, 1]))
        # Mark the object as fitted only once every scaler has been fit successfully, so that a failed fit does
        # not make __call__ skip fitting and use transform() with unfitted scalers.
        self.is_fitted_ = True
        return self

    def __call__(self, diag):
        """
        Apply DiagramScaler on a single persistence diagram and outputs the result.
        If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object rather than on the object itself.

        Parameters:
            diag (n x 2 numpy array): input persistence diagram.

        Returns:
            n x 2 numpy array: transformed persistence diagram.
        """
        # is_fitted_ is only set by fit(), so it serves as the "already fitted" marker.
        # NOTE(review): the fitted path calls self.transform(), which is not defined in this class as shown -- TODO confirm it is provided.
        return _maybe_fit_transform(self, 'is_fitted_', diag)
class Padding(BaseEstimator, TransformerMixin):
    """
    This is a class for padding a list of persistence diagrams with dummy points, so that all persistence diagrams end up with the same number of points.
    """

    def __init__(self, use=False):
        """
        Constructor for the Padding class.

        Parameters:
            use (bool): whether to use the class or not (default False).
        """
        self.use = use

    def fit(self, X, y=None):
        """
        Fit the Padding class on a list of persistence diagrams: the largest number of points found in a single diagram is stored in **max_pts_**.

        Parameters:
            X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
            y (n x 1 array): persistence diagram labels (unused).

        Returns:
            self
        """
        diagram_sizes = [len(diagram) for diagram in X]
        self.max_pts_ = max(diagram_sizes)
        return self

    def __call__(self, diag):
        """
        Apply Padding on a single persistence diagram and outputs the result.
        If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.

        Parameters:
            diag (n x 2 numpy array): input persistence diagram.

        Returns:
            n x 2 numpy array: padded persistence diagram.
        """
        # max_pts_ only exists after fit(), so it doubles as the "already fitted" marker.
        # NOTE(review): the fitted path calls self.transform(), which is not defined in this class as shown -- TODO confirm it is provided.
        fitted_marker = 'max_pts_'
        return _maybe_fit_transform(self, fitted_marker, diag)
class ProminentPoints(BaseEstimator, TransformerMixin):
    """
    This is a class for removing points that are close to or far from the diagonal in persistence diagrams. If persistence diagrams are n x 2 numpy arrays (i.e. persistence diagrams with ordinary features), points are ordered and thresholded by distance-to-diagonal. If persistence diagrams are n x 1 numpy arrays (i.e. persistence diagrams with essential features), points are not ordered and thresholded by first coordinate.
    """

    def __init__(self, use=False, num_pts=10, threshold=-1, location="upper"):
        """
        Constructor for the ProminentPoints class.

        Parameters:
            use (bool): whether to use the class or not (default False).
            num_pts (int): cardinality threshold (default 10). If location == "upper", keep the top **num_pts** points that are the farthest away from the diagonal. If location == "lower", keep the top **num_pts** points that are the closest to the diagonal.
            threshold (float): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
            location (string): either "upper" or "lower" (default "upper"). Whether to keep the points that are far away ("upper") or close ("lower") to the diagonal.
        """
        self.use = use
        self.location = location
        self.num_pts = num_pts
        self.threshold = threshold

    def fit(self, X, y=None):
        """
        Fit the ProminentPoints class on a list of persistence diagrams (this function actually does nothing but is useful when ProminentPoints is included in a scikit-learn Pipeline).

        Parameters:
            X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
            y (n x 1 array): persistence diagram labels (unused).

        Returns:
            self
        """
        return self

    def __call__(self, diag):
        """
        Apply ProminentPoints on a single persistence diagram and outputs the result.

        Parameters:
            diag (n x 2 numpy array): input persistence diagram.

        Returns:
            n x 2 numpy array: thresholded persistence diagram.
        """
        # transform() works on lists of diagrams: wrap the single diagram, then unwrap the result.
        # NOTE(review): self.transform() is not defined in this class as shown -- TODO confirm it is provided.
        transformed = self.transform([diag])
        return transformed[0]
class DiagramSelector(BaseEstimator, TransformerMixin):
    """
    This is a class for extracting finite or essential points in persistence diagrams.
    """

    def __init__(self, use=False, limit=np.inf, point_type="finite"):
        """
        Constructor for the DiagramSelector class.

        Parameters:
            use (bool): whether to use the class or not (default False).
            limit (float): second coordinate value that is the criterion for being an essential point (default numpy.inf).
            point_type (string): either "finite" or "essential". The type of the points that are going to be extracted.
        """
        self.use = use
        self.limit = limit
        self.point_type = point_type

    def fit(self, X, y=None):
        """
        Fit the DiagramSelector class on a list of persistence diagrams (this function actually does nothing but is useful when DiagramSelector is included in a scikit-learn Pipeline).

        Parameters:
            X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
            y (n x 1 array): persistence diagram labels (unused).

        Returns:
            self
        """
        return self

    def __call__(self, diag):
        """
        Apply DiagramSelector on a single persistence diagram and outputs the result.

        Parameters:
            diag (n x 2 numpy array): input persistence diagram.

        Returns:
            n x 2 numpy array: extracted persistence diagram.
        """
        # transform() works on lists of diagrams: wrap the single diagram, then unwrap the result.
        # NOTE(review): self.transform() is not defined in this class as shown -- TODO confirm it is provided.
        selected = self.transform([diag])
        return selected[0]
# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
# sequenceDiagram
# USER->>DimensionSelector: fit_transform(<br/>[[array( Hi(X0) ), array( Hj(X0) ), ...],<br/> [array( Hi(X1) ), array( Hj(X1) ), ...],<br/> ...])
# DimensionSelector->>thread1: _transform([array( Hi(X0) ), array( Hj(X0) )], ...)
# DimensionSelector->>thread2: _transform([array( Hi(X1) ), array( Hj(X1) )], ...)
# Note right of DimensionSelector: ...
# thread1->>DimensionSelector: array( Hn(X0) )
# thread2->>DimensionSelector: array( Hn(X1) )
# Note right of DimensionSelector: ...
# DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...]
class DimensionSelector(BaseEstimator, TransformerMixin):
    """
    This is a class to select, in each element of a list, the persistence diagram stored at a given index (typically one dimension).
    """

    def __init__(self, index=0):
        """
        Constructor for the DimensionSelector class.

        Parameters:
            index (int): position of the persistence diagram to return from each input element. Default value is `0`.
        """
        self.index = index

    def fit(self, X, Y=None):
        """
        Nothing to be done, but useful when included in a scikit-learn Pipeline.

        Returns:
            self
        """
        return self