import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler
# Utils #####################################
def _maybe_fit_transform(obj, attr, diag):
In __call__, use transform on the object itself if it has been fitted,
otherwise fit_transform on a clone of the object so it doesn't affect future calls.
if hasattr(obj, attr):
result = obj.transform([diag])
result = obj.__class__(**obj.get_params()).fit_transform([diag])
return result[0]
class Clamping(BaseEstimator, TransformerMixin):
    """
    This is a class for clamping a list of values. It is not meant to be called directly on (a list of) persistence diagrams, but it is rather meant to be used as a parameter for the DiagramScaler class. As such it has the same methods and purpose as common scalers from sklearn.preprocessing such as MinMaxScaler, RobustScaler, StandardScaler, etc. A typical use would be for instance if you want to clamp abscissae or ordinates (or both) of persistence diagrams within a pre-defined interval.
    """

    # NOTE(review): no `transform` method is visible in this part of the class;
    # confirm that the actual clamping is provided elsewhere.

    def __init__(self, minimum=-np.inf, maximum=np.inf):
        """
        Constructor for the Clamping class.

        Parameters:
            minimum (float): lower clamping value (default -np.inf).
            maximum (float): upper clamping value (default np.inf).
        """
        self.minimum = minimum
        self.maximum = maximum

    def fit(self, X, y=None):
        """
        Fit the Clamping class on a list of values (this function actually does nothing but is useful when Clamping is included in a scikit-learn Pipeline).

        Parameters:
            X (numpy array of size n): input values.
            y (n x 1 array): value labels (unused).

        Returns:
            Clamping: the instance itself, unchanged.
        """
        return self
# Preprocessing #############################
class DiagramScaler(BaseEstimator, TransformerMixin):
    """
    This is a class for preprocessing persistence diagrams with a given list of scalers, such as those included in scikit-learn.
    """

    # NOTE: the mutable default `scalers=[]` is kept for backward compatibility;
    # the methods visible here only read the list and never mutate it.
    def __init__(self, use=False, scalers=[]):
        """
        Constructor for the DiagramScaler class.

        Parameters:
            use (bool): whether to use the class or not (default False).
            scalers (list of classes): list of scalers to be fit on the persistence diagrams (default []). Each element of the list is a tuple with two elements: the first one is a list of coordinates, and the second one is a scaler (i.e. a class with fit() and transform() methods) that is going to be applied to these coordinates. Common scalers can be found in the scikit-learn library (such as MinMaxScaler for instance).
        """
        self.scalers = scalers
        self.use = use

    def fit(self, X, y=None):
        """
        Fit the DiagramScaler class on a list of persistence diagrams: persistence diagrams are concatenated in a big numpy array, and scalers are fit (by calling their fit() method) on their corresponding coordinates in this big array.

        Parameters:
            X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
            y (n x 1 array): persistence diagram labels (unused).

        Returns:
            DiagramScaler: the instance itself, with `is_fitted_` set.
        """
        # The fitted marker is set even when `use` is False, so that __call__
        # goes through transform() on this object rather than on a fresh copy.
        self.is_fitted_ = True
        if self.use:
            # np.concatenate also accepts a one-element list, so a single
            # diagram needs no special case.
            P = np.concatenate(X, 0)
            for (indices, scaler) in self.scalers:
                # All selected coordinates are flattened into a single column,
                # so each scaler is fit on one pooled set of values.
                scaler.fit(np.reshape(P[:, indices], [-1, 1]))
        return self

    def __call__(self, diag):
        """
        Apply DiagramScaler on a single persistence diagram and outputs the result.
        If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.

        Parameters:
            diag (n x 2 numpy array): input persistence diagram.

        Returns:
            n x 2 numpy array: transformed persistence diagram.
        """
        # NOTE(review): relies on a `transform` method that is not visible in
        # this part of the class; confirm it is provided.
        return _maybe_fit_transform(self, 'is_fitted_', diag)
class Padding(BaseEstimator, TransformerMixin):
    """
    This is a class for padding a list of persistence diagrams with dummy points, so that all persistence diagrams end up with the same number of points.
    """

    def __init__(self, use=False):
        """
        Constructor for the Padding class.

        Parameters:
            use (bool): whether to use the class or not (default False).
        """
        self.use = use

    def fit(self, X, y=None):
        """
        Fit the Padding class on a list of persistence diagrams.

        Parameters:
            X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
            y (n x 1 array): persistence diagram labels (unused).

        Returns:
            Padding: the instance itself, with `max_pts_` set to the largest diagram length.

        Raises:
            ValueError: if X is empty (there is no maximum to take).
        """
        # Computed regardless of `use`; its presence is also what marks the
        # object as fitted for __call__.
        self.max_pts_ = max(len(diag) for diag in X)
        return self

    def __call__(self, diag):
        """
        Apply Padding on a single persistence diagram and outputs the result.
        If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.

        Parameters:
            diag (n x 2 numpy array): input persistence diagram.

        Returns:
            n x 2 numpy array: padded persistence diagram.
        """
        # NOTE(review): relies on a `transform` method that is not visible in
        # this part of the class; confirm it is provided.
        return _maybe_fit_transform(self, 'max_pts_', diag)
class ProminentPoints(BaseEstimator, TransformerMixin):
    """
    This is a class for removing points that are close or far from the diagonal in persistence diagrams. If persistence diagrams are n x 2 numpy arrays (i.e. persistence diagrams with ordinary features), points are ordered and thresholded by distance-to-diagonal. If persistence diagrams are n x 1 numpy arrays (i.e. persistence diagrams with essential features), points are not ordered and thresholded by first coordinate.
    """

    def __init__(self, use=False, num_pts=10, threshold=-1, location="upper"):
        """
        Constructor for the ProminentPoints class.

        Parameters:
            use (bool): whether to use the class or not (default False).
            num_pts (int): cardinality threshold (default 10). If location == "upper", keep the top **num_pts** points that are the farthest away from the diagonal. If location == "lower", keep the top **num_pts** points that are the closest to the diagonal.
            threshold (float): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
            location (string): either "upper" or "lower" (default "upper"). Whether to keep the points that are far away ("upper") or close ("lower") to the diagonal.
        """
        self.num_pts = num_pts
        self.threshold = threshold
        self.use = use
        self.location = location

    def fit(self, X, y=None):
        """
        Fit the ProminentPoints class on a list of persistence diagrams (this function actually does nothing but is useful when ProminentPoints is included in a scikit-learn Pipeline).

        Parameters:
            X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
            y (n x 1 array): persistence diagram labels (unused).

        Returns:
            ProminentPoints: the instance itself, unchanged.
        """
        return self

    def __call__(self, diag):
        """
        Apply ProminentPoints on a single persistence diagram and outputs the result.

        Parameters:
            diag (n x 2 numpy array): input persistence diagram.

        Returns:
            n x 2 numpy array: thresholded persistence diagram.
        """
        # The diagram is wrapped in a one-element list and unwrapped afterwards.
        # NOTE(review): relies on a `transform` method that is not visible in
        # this part of the class; confirm it is provided.
        return self.transform([diag])[0]
class DiagramSelector(BaseEstimator, TransformerMixin):
    """
    This is a class for extracting finite or essential points in persistence diagrams.
    """

    def __init__(self, use=False, limit=np.inf, point_type="finite"):
        """
        Constructor for the DiagramSelector class.

        Parameters:
            use (bool): whether to use the class or not (default False).
            limit (float): second coordinate value that is the criterion for being an essential point (default numpy.inf).
            point_type (string): either "finite" or "essential" (default "finite"). The type of the points that are going to be extracted.
        """
        self.use = use
        self.limit = limit
        self.point_type = point_type

    def fit(self, X, y=None):
        """
        Fit the DiagramSelector class on a list of persistence diagrams (this function actually does nothing but is useful when DiagramSelector is included in a scikit-learn Pipeline).

        Parameters:
            X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
            y (n x 1 array): persistence diagram labels (unused).

        Returns:
            DiagramSelector: the instance itself, unchanged.
        """
        return self

    def __call__(self, diag):
        """
        Apply DiagramSelector on a single persistence diagram and outputs the result.

        Parameters:
            diag (n x 2 numpy array): input persistence diagram.

        Returns:
            n x 2 numpy array: extracted persistence diagram.
        """
        # The diagram is wrapped in a one-element list and unwrapped afterwards.
        # NOTE(review): relies on a `transform` method that is not visible in
        # this part of the class; confirm it is provided.
        return self.transform([diag])[0]
class DimensionSelector(BaseEstimator, TransformerMixin):
    """
    This is a class to select persistence diagrams in a specific dimension from its index.
    """

    def __init__(self, index=0):
        """
        Constructor for the DimensionSelector class.

        Parameters:
            index (int): The returned persistence diagrams dimension index. Default value is `0`.
        """
        self.index = index

    # The second parameter is spelled `Y` (capital), unlike the `y` used by the
    # other classes in this file; it is kept as-is because renaming it would
    # break callers that pass it by keyword.
    def fit(self, X, Y=None):
        """
        Nothing to be done, but useful when included in a scikit-learn Pipeline.

        Parameters:
            X: input data (unused).
            Y: labels (unused).

        Returns:
            DimensionSelector: the instance itself, unchanged.
        """
        return self