Sliced_Wasserstein.h
1 /* This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
2  * See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
3  * Author(s): Mathieu Carriere
4  *
5  * Copyright (C) 2018 Inria
6  *
7  * Modification(s):
8  * - YYYY/MM Author: Description of the modification
9  */
10 
11 #ifndef SLICED_WASSERSTEIN_H_
12 #define SLICED_WASSERSTEIN_H_
13 
14 // gudhi include
15 #include <gudhi/read_persistence_from_file.h>
16 #include <gudhi/common_persistence_representations.h>
17 #include <gudhi/Debug_utils.h>
18 
19 #include <vector> // for std::vector<>
20 #include <utility> // for std::pair<>, std::move
21 #include <algorithm> // for std::sort, std::max, std::merge
22 #include <cmath> // for std::abs, std::sqrt
23 #include <stdexcept> // for std::invalid_argument
24 #include <random> // for std::random_device
25 
26 namespace Gudhi {
27 namespace Persistence_representations {
28 
63  protected:
64  Persistence_diagram diagram;
65  int approx;
66  double sigma;
67  std::vector<std::vector<double> > projections, projections_diagonal;
68 
69  // **********************************
70  // Utils.
71  // **********************************
72 
73  void build_rep() {
74  if (approx > 0) {
75  double step = pi / this->approx;
76  int n = diagram.size();
77 
78  for (int i = 0; i < this->approx; i++) {
79  std::vector<double> l, l_diag;
80  for (int j = 0; j < n; j++) {
81  double px = diagram[j].first;
82  double py = diagram[j].second;
83  double proj_diag = (px + py) / 2;
84 
85  l.push_back(px * cos(-pi / 2 + i * step) + py * sin(-pi / 2 + i * step));
86  l_diag.push_back(proj_diag * cos(-pi / 2 + i * step) + proj_diag * sin(-pi / 2 + i * step));
87  }
88 
89  std::sort(l.begin(), l.end());
90  std::sort(l_diag.begin(), l_diag.end());
91  projections.push_back(std::move(l));
92  projections_diagonal.push_back(std::move(l_diag));
93  }
94 
95  diagram.clear();
96  }
97  }
98 
99  // Compute the angle formed by two points of a PD
100  double compute_angle(const Persistence_diagram& diag, int i, int j) const {
101  if (diag[i].second == diag[j].second)
102  return pi / 2;
103  else
104  return atan((diag[j].first - diag[i].first) / (diag[i].second - diag[j].second));
105  }
106 
107  // Compute the integral of |cos()| between alpha and beta, valid only if alpha is in [-pi,pi] and beta-alpha is in
108  // [0,pi]
109  double compute_int_cos(double alpha, double beta) const {
110  double res = 0;
111  if (alpha >= 0 && alpha <= pi) {
112  if (cos(alpha) >= 0) {
113  if (pi / 2 <= beta) {
114  res = 2 - sin(alpha) - sin(beta);
115  } else {
116  res = sin(beta) - sin(alpha);
117  }
118  } else {
119  if (1.5 * pi <= beta) {
120  res = 2 + sin(alpha) + sin(beta);
121  } else {
122  res = sin(alpha) - sin(beta);
123  }
124  }
125  }
126  if (alpha >= -pi && alpha <= 0) {
127  if (cos(alpha) <= 0) {
128  if (-pi / 2 <= beta) {
129  res = 2 + sin(alpha) + sin(beta);
130  } else {
131  res = sin(alpha) - sin(beta);
132  }
133  } else {
134  if (pi / 2 <= beta) {
135  res = 2 - sin(alpha) - sin(beta);
136  } else {
137  res = sin(beta) - sin(alpha);
138  }
139  }
140  }
141  return res;
142  }
143 
144  double compute_int(double theta1, double theta2, int p, int q, const Persistence_diagram& diag1,
145  const Persistence_diagram& diag2) const {
146  double norm = std::sqrt((diag1[p].first - diag2[q].first) * (diag1[p].first - diag2[q].first) +
147  (diag1[p].second - diag2[q].second) * (diag1[p].second - diag2[q].second));
148  double angle1;
149  if (diag1[p].first == diag2[q].first)
150  angle1 = theta1 - pi / 2;
151  else
152  angle1 = theta1 - atan((diag1[p].second - diag2[q].second) / (diag1[p].first - diag2[q].first));
153  double angle2 = angle1 + theta2 - theta1;
154  double integral = compute_int_cos(angle1, angle2);
155  return norm * integral;
156  }
157 
158  // Evaluation of the Sliced Wasserstein Distance between a pair of diagrams.
159  double compute_sliced_wasserstein_distance(const Sliced_Wasserstein& second) const {
160  GUDHI_CHECK(this->approx == second.approx,
161  std::invalid_argument("Error: different approx values for representations"));
162 
163  Persistence_diagram diagram1 = this->diagram;
164  Persistence_diagram diagram2 = second.diagram;
165  double sw = 0;
166 
167  if (this->approx == -1) {
168  // Add projections onto diagonal.
169  int n1, n2;
170  n1 = diagram1.size();
171  n2 = diagram2.size();
172  double min_ordinate = std::numeric_limits<double>::max();
173  double min_abscissa = std::numeric_limits<double>::max();
174  double max_ordinate = std::numeric_limits<double>::lowest();
175  double max_abscissa = std::numeric_limits<double>::lowest();
176  for (int i = 0; i < n2; i++) {
177  min_ordinate = std::min(min_ordinate, diagram2[i].second);
178  min_abscissa = std::min(min_abscissa, diagram2[i].first);
179  max_ordinate = std::max(max_ordinate, diagram2[i].second);
180  max_abscissa = std::max(max_abscissa, diagram2[i].first);
181  diagram1.emplace_back((diagram2[i].first + diagram2[i].second) / 2,
182  (diagram2[i].first + diagram2[i].second) / 2);
183  }
184  for (int i = 0; i < n1; i++) {
185  min_ordinate = std::min(min_ordinate, diagram1[i].second);
186  min_abscissa = std::min(min_abscissa, diagram1[i].first);
187  max_ordinate = std::max(max_ordinate, diagram1[i].second);
188  max_abscissa = std::max(max_abscissa, diagram1[i].first);
189  diagram2.emplace_back((diagram1[i].first + diagram1[i].second) / 2,
190  (diagram1[i].first + diagram1[i].second) / 2);
191  }
192  int num_pts_dgm = diagram1.size();
193 
194  // Slightly perturb the points so that the PDs are in generic positions.
195  double epsilon = 0.0001;
196  double thresh_y = (max_ordinate - min_ordinate) * epsilon;
197  double thresh_x = (max_abscissa - min_abscissa) * epsilon;
198  std::random_device rd;
199  std::default_random_engine re(rd());
200  std::uniform_real_distribution<double> uni(-1, 1);
201  for (int i = 0; i < num_pts_dgm; i++) {
202  double u = uni(re);
203  diagram1[i].first += u * thresh_x;
204  diagram1[i].second += u * thresh_y;
205  diagram2[i].first += u * thresh_x;
206  diagram2[i].second += u * thresh_y;
207  }
208 
209  // Compute all angles in both PDs.
210  std::vector<std::pair<double, std::pair<int, int> > > angles1, angles2;
211  for (int i = 0; i < num_pts_dgm; i++) {
212  for (int j = i + 1; j < num_pts_dgm; j++) {
213  double theta1 = compute_angle(diagram1, i, j);
214  double theta2 = compute_angle(diagram2, i, j);
215  angles1.emplace_back(theta1, std::pair<int, int>(i, j));
216  angles2.emplace_back(theta2, std::pair<int, int>(i, j));
217  }
218  }
219 
220  // Sort angles.
221  std::sort(angles1.begin(), angles1.end(),
222  [](const std::pair<double, std::pair<int, int> >& p1,
223  const std::pair<double, std::pair<int, int> >& p2) { return (p1.first < p2.first); });
224  std::sort(angles2.begin(), angles2.end(),
225  [](const std::pair<double, std::pair<int, int> >& p1,
226  const std::pair<double, std::pair<int, int> >& p2) { return (p1.first < p2.first); });
227 
228  // Initialize orders of the points of both PDs (given by ordinates when theta = -pi/2).
229  std::vector<int> orderp1, orderp2;
230  for (int i = 0; i < num_pts_dgm; i++) {
231  orderp1.push_back(i);
232  orderp2.push_back(i);
233  }
234  std::sort(orderp1.begin(), orderp1.end(), [&](int i, int j) {
235  if (diagram1[i].second != diagram1[j].second)
236  return (diagram1[i].second < diagram1[j].second);
237  else
238  return (diagram1[i].first > diagram1[j].first);
239  });
240  std::sort(orderp2.begin(), orderp2.end(), [&](int i, int j) {
241  if (diagram2[i].second != diagram2[j].second)
242  return (diagram2[i].second < diagram2[j].second);
243  else
244  return (diagram2[i].first > diagram2[j].first);
245  });
246 
247  // Find the inverses of the orders.
248  std::vector<int> order1(num_pts_dgm);
249  std::vector<int> order2(num_pts_dgm);
250  for (int i = 0; i < num_pts_dgm; i++) {
251  order1[orderp1[i]] = i;
252  order2[orderp2[i]] = i;
253  }
254 
255  // Record all inversions of points in the orders as theta varies along the positive half-disk.
256  std::vector<std::vector<std::pair<int, double> > > anglePerm1(num_pts_dgm);
257  std::vector<std::vector<std::pair<int, double> > > anglePerm2(num_pts_dgm);
258 
259  int m1 = angles1.size();
260  for (int i = 0; i < m1; i++) {
261  double theta = angles1[i].first;
262  int p = angles1[i].second.first;
263  int q = angles1[i].second.second;
264  anglePerm1[order1[p]].emplace_back(p, theta);
265  anglePerm1[order1[q]].emplace_back(q, theta);
266  int a = order1[p];
267  int b = order1[q];
268  order1[p] = b;
269  order1[q] = a;
270  }
271 
272  int m2 = angles2.size();
273  for (int i = 0; i < m2; i++) {
274  double theta = angles2[i].first;
275  int p = angles2[i].second.first;
276  int q = angles2[i].second.second;
277  anglePerm2[order2[p]].emplace_back(p, theta);
278  anglePerm2[order2[q]].emplace_back(q, theta);
279  int a = order2[p];
280  int b = order2[q];
281  order2[p] = b;
282  order2[q] = a;
283  }
284 
285  for (int i = 0; i < num_pts_dgm; i++) {
286  anglePerm1[order1[i]].emplace_back(i, pi / 2);
287  anglePerm2[order2[i]].emplace_back(i, pi / 2);
288  }
289 
290  // Compute the SW distance with the list of inversions.
291  for (int i = 0; i < num_pts_dgm; i++) {
292  std::vector<std::pair<int, double> > u, v;
293  u = anglePerm1[i];
294  v = anglePerm2[i];
295  double theta1, theta2;
296  theta1 = -pi / 2;
297  unsigned int ku, kv;
298  ku = 0;
299  kv = 0;
300  theta2 = std::min(u[ku].second, v[kv].second);
301  while (theta1 != pi / 2) {
302  if (diagram1[u[ku].first].first != diagram2[v[kv].first].first ||
303  diagram1[u[ku].first].second != diagram2[v[kv].first].second)
304  if (theta1 != theta2) sw += compute_int(theta1, theta2, u[ku].first, v[kv].first, diagram1, diagram2);
305  theta1 = theta2;
306  if ((theta2 == u[ku].second) && ku < u.size() - 1) ku++;
307  if ((theta2 == v[kv].second) && kv < v.size() - 1) kv++;
308  theta2 = std::min(u[ku].second, v[kv].second);
309  }
310  }
311  } else {
312  double step = pi / this->approx;
313  std::vector<double> v1, v2;
314  for (int i = 0; i < this->approx; i++) {
315  v1.clear();
316  v2.clear();
317  std::merge(this->projections[i].begin(), this->projections[i].end(), second.projections_diagonal[i].begin(),
318  second.projections_diagonal[i].end(), std::back_inserter(v1));
319  std::merge(second.projections[i].begin(), second.projections[i].end(), this->projections_diagonal[i].begin(),
320  this->projections_diagonal[i].end(), std::back_inserter(v2));
321 
322  int n = v1.size();
323  double f = 0;
324  for (int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]);
325  sw += f * step;
326  }
327  }
328 
329  return sw / pi;
330  }
331 
332  public:
344  Sliced_Wasserstein(const Persistence_diagram& _diagram, double _sigma = 1.0, int _approx = 10)
345  : diagram(_diagram), approx(_approx), sigma(_sigma) {
346  build_rep();
347  }
348 
356  double compute_scalar_product(const Sliced_Wasserstein& second) const {
357  GUDHI_CHECK(this->sigma == second.sigma,
358  std::invalid_argument("Error: different sigma values for representations"));
359  return std::exp(-compute_sliced_wasserstein_distance(second) / (2 * this->sigma * this->sigma));
360  }
361 
369  double distance(const Sliced_Wasserstein& second) const {
370  GUDHI_CHECK(this->sigma == second.sigma,
371  std::invalid_argument("Error: different sigma values for representations"));
372  return std::sqrt(this->compute_scalar_product(*this) + second.compute_scalar_product(second) -
373  2 * this->compute_scalar_product(second));
374  }
375 
376 }; // class Sliced_Wasserstein
377 } // namespace Persistence_representations
378 } // namespace Gudhi
379 
380 #endif // SLICED_WASSERSTEIN_H_
Sliced_Wasserstein(const Persistence_diagram &_diagram, double _sigma=1.0, int _approx=10)
Sliced Wasserstein kernel constructor.
Definition: Sliced_Wasserstein.h:344
A class implementing the Sliced Wasserstein kernel.
Definition: Sliced_Wasserstein.h:62
double compute_scalar_product(const Sliced_Wasserstein &second) const
Evaluation of the kernel on a pair of diagrams.
Definition: Sliced_Wasserstein.h:356
Definition: SimplicialComplexForAlpha.h:14
double distance(const Sliced_Wasserstein &second) const
Evaluation of the distance between images of diagrams in the Hilbert space of the kernel...
Definition: Sliced_Wasserstein.h:369
GUDHI  Version 3.4.1  - C++ library for Topological Data Analysis (TDA) and Higher Dimensional Geometry Understanding.  - Copyright : MIT Generated on Fri Jan 22 2021 09:41:15 for GUDHI by Doxygen 1.8.13