11 #ifndef SLICED_WASSERSTEIN_H_    12 #define SLICED_WASSERSTEIN_H_    15 #include <gudhi/read_persistence_from_file.h>    16 #include <gudhi/common_persistence_representations.h>    17 #include <gudhi/Debug_utils.h>    27 namespace Persistence_representations {
    64   Persistence_diagram diagram;
    67   std::vector<std::vector<double> > projections, projections_diagonal;
    75       double step = pi / this->approx;
    76       int n = diagram.size();
    78       for (
int i = 0; i < this->approx; i++) {
    79         std::vector<double> l, l_diag;
    80         for (
int j = 0; j < n; j++) {
    81           double px = diagram[j].first;
    82           double py = diagram[j].second;
    83           double proj_diag = (px + py) / 2;
    85           l.push_back(px * cos(-pi / 2 + i * step) + py * sin(-pi / 2 + i * step));
    86           l_diag.push_back(proj_diag * cos(-pi / 2 + i * step) + proj_diag * sin(-pi / 2 + i * step));
    89         std::sort(l.begin(), l.end());
    90         std::sort(l_diag.begin(), l_diag.end());
    91         projections.push_back(std::move(l));
    92         projections_diagonal.push_back(std::move(l_diag));
   100   double compute_angle(
const Persistence_diagram& diag, 
int i, 
int j)
 const {
   101     if (diag[i].second == diag[j].second)
   104       return atan((diag[j].first - diag[i].first) / (diag[i].second - diag[j].second));
   109   double compute_int_cos(
double alpha, 
double beta)
 const {
   111     if (alpha >= 0 && alpha <= pi) {
   112       if (cos(alpha) >= 0) {
   113         if (pi / 2 <= beta) {
   114           res = 2 - sin(alpha) - sin(beta);
   116           res = sin(beta) - sin(alpha);
   119         if (1.5 * pi <= beta) {
   120           res = 2 + sin(alpha) + sin(beta);
   122           res = sin(alpha) - sin(beta);
   126     if (alpha >= -pi && alpha <= 0) {
   127       if (cos(alpha) <= 0) {
   128         if (-pi / 2 <= beta) {
   129           res = 2 + sin(alpha) + sin(beta);
   131           res = sin(alpha) - sin(beta);
   134         if (pi / 2 <= beta) {
   135           res = 2 - sin(alpha) - sin(beta);
   137           res = sin(beta) - sin(alpha);
   144   double compute_int(
double theta1, 
double theta2, 
int p, 
int q, 
const Persistence_diagram& diag1,
   145                      const Persistence_diagram& diag2)
 const {
   146     double norm = std::sqrt((diag1[p].first - diag2[q].first) * (diag1[p].first - diag2[q].first) +
   147                             (diag1[p].second - diag2[q].second) * (diag1[p].second - diag2[q].second));
   149     if (diag1[p].first == diag2[q].first)
   150       angle1 = theta1 - pi / 2;
   152       angle1 = theta1 - atan((diag1[p].second - diag2[q].second) / (diag1[p].first - diag2[q].first));
   153     double angle2 = angle1 + theta2 - theta1;
   154     double integral = compute_int_cos(angle1, angle2);
   155     return norm * integral;
   160     GUDHI_CHECK(this->approx == second.approx,
   161                 std::invalid_argument(
"Error: different approx values for representations"));
   163     Persistence_diagram diagram1 = this->diagram;
   164     Persistence_diagram diagram2 = second.diagram;
   167     if (this->approx == -1) {
   170       n1 = diagram1.size();
   171       n2 = diagram2.size();
   172       double min_ordinate = std::numeric_limits<double>::max();
   173       double min_abscissa = std::numeric_limits<double>::max();
   174       double max_ordinate = std::numeric_limits<double>::lowest();
   175       double max_abscissa = std::numeric_limits<double>::lowest();
   176       for (
int i = 0; i < n2; i++) {
   177         min_ordinate = std::min(min_ordinate, diagram2[i].second);
   178         min_abscissa = std::min(min_abscissa, diagram2[i].first);
   179         max_ordinate = std::max(max_ordinate, diagram2[i].second);
   180         max_abscissa = std::max(max_abscissa, diagram2[i].first);
   181         diagram1.emplace_back((diagram2[i].first + diagram2[i].second) / 2,
   182                               (diagram2[i].first + diagram2[i].second) / 2);
   184       for (
int i = 0; i < n1; i++) {
   185         min_ordinate = std::min(min_ordinate, diagram1[i].second);
   186         min_abscissa = std::min(min_abscissa, diagram1[i].first);
   187         max_ordinate = std::max(max_ordinate, diagram1[i].second);
   188         max_abscissa = std::max(max_abscissa, diagram1[i].first);
   189         diagram2.emplace_back((diagram1[i].first + diagram1[i].second) / 2,
   190                               (diagram1[i].first + diagram1[i].second) / 2);
   192       int num_pts_dgm = diagram1.size();
   195       double epsilon = 0.0001;
   196       double thresh_y = (max_ordinate - min_ordinate) * epsilon;
   197       double thresh_x = (max_abscissa - min_abscissa) * epsilon;
   198       std::random_device rd;
   199       std::default_random_engine re(rd());
   200       std::uniform_real_distribution<double> uni(-1, 1);
   201       for (
int i = 0; i < num_pts_dgm; i++) {
   203         diagram1[i].first += u * thresh_x;
   204         diagram1[i].second += u * thresh_y;
   205         diagram2[i].first += u * thresh_x;
   206         diagram2[i].second += u * thresh_y;
   210       std::vector<std::pair<double, std::pair<int, int> > > angles1, angles2;
   211       for (
int i = 0; i < num_pts_dgm; i++) {
   212         for (
int j = i + 1; j < num_pts_dgm; j++) {
   213           double theta1 = compute_angle(diagram1, i, j);
   214           double theta2 = compute_angle(diagram2, i, j);
   215           angles1.emplace_back(theta1, std::pair<int, int>(i, j));
   216           angles2.emplace_back(theta2, std::pair<int, int>(i, j));
   221       std::sort(angles1.begin(), angles1.end(),
   222                 [](
const std::pair<double, std::pair<int, int> >& p1,
   223                    const std::pair<double, std::pair<int, int> >& p2) { 
return (p1.first < p2.first); });
   224       std::sort(angles2.begin(), angles2.end(),
   225                 [](
const std::pair<double, std::pair<int, int> >& p1,
   226                    const std::pair<double, std::pair<int, int> >& p2) { 
return (p1.first < p2.first); });
   229       std::vector<int> orderp1, orderp2;
   230       for (
int i = 0; i < num_pts_dgm; i++) {
   231         orderp1.push_back(i);
   232         orderp2.push_back(i);
   234       std::sort(orderp1.begin(), orderp1.end(), [&](
int i, 
int j) {
   235         if (diagram1[i].second != diagram1[j].second)
   236           return (diagram1[i].second < diagram1[j].second);
   238           return (diagram1[i].first > diagram1[j].first);
   240       std::sort(orderp2.begin(), orderp2.end(), [&](
int i, 
int j) {
   241         if (diagram2[i].second != diagram2[j].second)
   242           return (diagram2[i].second < diagram2[j].second);
   244           return (diagram2[i].first > diagram2[j].first);
   248       std::vector<int> order1(num_pts_dgm);
   249       std::vector<int> order2(num_pts_dgm);
   250       for (
int i = 0; i < num_pts_dgm; i++) {
   251         order1[orderp1[i]] = i;
   252         order2[orderp2[i]] = i;
   256       std::vector<std::vector<std::pair<int, double> > > anglePerm1(num_pts_dgm);
   257       std::vector<std::vector<std::pair<int, double> > > anglePerm2(num_pts_dgm);
   259       int m1 = angles1.size();
   260       for (
int i = 0; i < m1; i++) {
   261         double theta = angles1[i].first;
   262         int p = angles1[i].second.first;
   263         int q = angles1[i].second.second;
   264         anglePerm1[order1[p]].emplace_back(p, theta);
   265         anglePerm1[order1[q]].emplace_back(q, theta);
   272       int m2 = angles2.size();
   273       for (
int i = 0; i < m2; i++) {
   274         double theta = angles2[i].first;
   275         int p = angles2[i].second.first;
   276         int q = angles2[i].second.second;
   277         anglePerm2[order2[p]].emplace_back(p, theta);
   278         anglePerm2[order2[q]].emplace_back(q, theta);
   285       for (
int i = 0; i < num_pts_dgm; i++) {
   286         anglePerm1[order1[i]].emplace_back(i, pi / 2);
   287         anglePerm2[order2[i]].emplace_back(i, pi / 2);
   291       for (
int i = 0; i < num_pts_dgm; i++) {
   292         std::vector<std::pair<int, double> > u, v;
   295         double theta1, theta2;
   300         theta2 = std::min(u[ku].second, v[kv].second);
   301         while (theta1 != pi / 2) {
   302           if (diagram1[u[ku].first].first != diagram2[v[kv].first].first ||
   303               diagram1[u[ku].first].second != diagram2[v[kv].first].second)
   304             if (theta1 != theta2) sw += compute_int(theta1, theta2, u[ku].first, v[kv].first, diagram1, diagram2);
   306           if ((theta2 == u[ku].second) && ku < u.size() - 1) ku++;
   307           if ((theta2 == v[kv].second) && kv < v.size() - 1) kv++;
   308           theta2 = std::min(u[ku].second, v[kv].second);
   312       double step = pi / this->approx;
   313       std::vector<double> v1, v2;
   314       for (
int i = 0; i < this->approx; i++) {
   317         std::merge(this->projections[i].begin(), this->projections[i].end(), second.projections_diagonal[i].begin(),
   318                    second.projections_diagonal[i].end(), std::back_inserter(v1));
   319         std::merge(second.projections[i].begin(), second.projections[i].end(), this->projections_diagonal[i].begin(),
   320                    this->projections_diagonal[i].end(), std::back_inserter(v2));
   324         for (
int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]);
   345       : diagram(_diagram), approx(_approx), sigma(_sigma) {
   357     GUDHI_CHECK(this->sigma == second.sigma,
   358                 std::invalid_argument(
"Error: different sigma values for representations"));
   359     return std::exp(-compute_sliced_wasserstein_distance(second) / (2 * this->sigma * this->sigma));
   370     GUDHI_CHECK(this->sigma == second.sigma,
   371                 std::invalid_argument(
"Error: different sigma values for representations"));
   380 #endif  // SLICED_WASSERSTEIN_H_ Sliced_Wasserstein(const Persistence_diagram &_diagram, double _sigma=1.0, int _approx=10)
Sliced Wasserstein kernel constructor. 
Definition: Sliced_Wasserstein.h:344
A class implementing the Sliced Wasserstein kernel. 
Definition: Sliced_Wasserstein.h:62
double compute_scalar_product(const Sliced_Wasserstein &second) const
Evaluation of the kernel on a pair of diagrams. 
Definition: Sliced_Wasserstein.h:356
Definition: SimplicialComplexForAlpha.h:14
double distance(const Sliced_Wasserstein &second) const
Evaluation of the distance between images of diagrams in the Hilbert space of the kernel...
Definition: Sliced_Wasserstein.h:369