Loading...
Searching...
No Matches
Sliced_Wasserstein.h
1/* This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
2 * See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
3 * Author(s): Mathieu Carriere
4 *
5 * Copyright (C) 2018 Inria
6 *
7 * Modification(s):
8 * - YYYY/MM Author: Description of the modification
9 */
10
11#ifndef SLICED_WASSERSTEIN_H_
12#define SLICED_WASSERSTEIN_H_
13
14// gudhi include
15#include <gudhi/read_persistence_from_file.h>
16#include <gudhi/common_persistence_representations.h>
17#include <gudhi/Debug_utils.h>
18
19#include <vector> // for std::vector<>
20#include <utility> // for std::pair<>, std::move
21#include <algorithm> // for std::sort, std::max, std::merge
22#include <cmath> // for std::abs, std::sqrt
23#include <stdexcept> // for std::invalid_argument
24#include <random> // for std::random_device
25
26namespace Gudhi {
27namespace Persistence_representations {
28
63 protected:
64 Persistence_diagram diagram;
65 int approx;
66 double sigma;
67 std::vector<std::vector<double> > projections, projections_diagonal;
68
69 // **********************************
70 // Utils.
71 // **********************************
72
73 void build_rep() {
74 if (approx > 0) {
75 double step = pi / this->approx;
76 int n = diagram.size();
77
78 for (int i = 0; i < this->approx; i++) {
79 std::vector<double> l, l_diag;
80 for (int j = 0; j < n; j++) {
81 double px = diagram[j].first;
82 double py = diagram[j].second;
83 double proj_diag = (px + py) / 2;
84
85 l.push_back(px * cos(-pi / 2 + i * step) + py * sin(-pi / 2 + i * step));
86 l_diag.push_back(proj_diag * cos(-pi / 2 + i * step) + proj_diag * sin(-pi / 2 + i * step));
87 }
88
89 std::sort(l.begin(), l.end());
90 std::sort(l_diag.begin(), l_diag.end());
91 projections.push_back(std::move(l));
92 projections_diagonal.push_back(std::move(l_diag));
93 }
94
95 diagram.clear();
96 }
97 }
98
99 // Compute the angle formed by two points of a PD
100 double compute_angle(const Persistence_diagram& diag, int i, int j) const {
101 if (diag[i].second == diag[j].second)
102 return pi / 2;
103 else
104 return atan((diag[j].first - diag[i].first) / (diag[i].second - diag[j].second));
105 }
106
107 // Compute the integral of |cos()| between alpha and beta, valid only if alpha is in [-pi,pi] and beta-alpha is in
108 // [0,pi]
109 double compute_int_cos(double alpha, double beta) const {
110 double res = 0;
111 if (alpha >= 0 && alpha <= pi) {
112 if (cos(alpha) >= 0) {
113 if (pi / 2 <= beta) {
114 res = 2 - sin(alpha) - sin(beta);
115 } else {
116 res = sin(beta) - sin(alpha);
117 }
118 } else {
119 if (1.5 * pi <= beta) {
120 res = 2 + sin(alpha) + sin(beta);
121 } else {
122 res = sin(alpha) - sin(beta);
123 }
124 }
125 }
126 if (alpha >= -pi && alpha <= 0) {
127 if (cos(alpha) <= 0) {
128 if (-pi / 2 <= beta) {
129 res = 2 + sin(alpha) + sin(beta);
130 } else {
131 res = sin(alpha) - sin(beta);
132 }
133 } else {
134 if (pi / 2 <= beta) {
135 res = 2 - sin(alpha) - sin(beta);
136 } else {
137 res = sin(beta) - sin(alpha);
138 }
139 }
140 }
141 return res;
142 }
143
144 double compute_int(double theta1, double theta2, int p, int q, const Persistence_diagram& diag1,
145 const Persistence_diagram& diag2) const {
146 double norm = std::sqrt((diag1[p].first - diag2[q].first) * (diag1[p].first - diag2[q].first) +
147 (diag1[p].second - diag2[q].second) * (diag1[p].second - diag2[q].second));
148 double angle1;
149 if (diag1[p].first == diag2[q].first)
150 angle1 = theta1 - pi / 2;
151 else
152 angle1 = theta1 - atan((diag1[p].second - diag2[q].second) / (diag1[p].first - diag2[q].first));
153 double angle2 = angle1 + theta2 - theta1;
154 double integral = compute_int_cos(angle1, angle2);
155 return norm * integral;
156 }
157
158 // Evaluation of the Sliced Wasserstein Distance between a pair of diagrams.
159 double compute_sliced_wasserstein_distance(const Sliced_Wasserstein& second) const {
160 GUDHI_CHECK(this->approx == second.approx,
161 std::invalid_argument("Error: different approx values for representations"));
162
163 Persistence_diagram diagram1 = this->diagram;
164 Persistence_diagram diagram2 = second.diagram;
165 double sw = 0;
166
167 if (this->approx == -1) {
168 // Add projections onto diagonal.
169 int n1, n2;
170 n1 = diagram1.size();
171 n2 = diagram2.size();
172 double min_ordinate = std::numeric_limits<double>::max();
173 double min_abscissa = std::numeric_limits<double>::max();
174 double max_ordinate = std::numeric_limits<double>::lowest();
175 double max_abscissa = std::numeric_limits<double>::lowest();
176 for (int i = 0; i < n2; i++) {
177 min_ordinate = std::min(min_ordinate, diagram2[i].second);
178 min_abscissa = std::min(min_abscissa, diagram2[i].first);
179 max_ordinate = std::max(max_ordinate, diagram2[i].second);
180 max_abscissa = std::max(max_abscissa, diagram2[i].first);
181 diagram1.emplace_back((diagram2[i].first + diagram2[i].second) / 2,
182 (diagram2[i].first + diagram2[i].second) / 2);
183 }
184 for (int i = 0; i < n1; i++) {
185 min_ordinate = std::min(min_ordinate, diagram1[i].second);
186 min_abscissa = std::min(min_abscissa, diagram1[i].first);
187 max_ordinate = std::max(max_ordinate, diagram1[i].second);
188 max_abscissa = std::max(max_abscissa, diagram1[i].first);
189 diagram2.emplace_back((diagram1[i].first + diagram1[i].second) / 2,
190 (diagram1[i].first + diagram1[i].second) / 2);
191 }
192 int num_pts_dgm = diagram1.size();
193
194 // Slightly perturb the points so that the PDs are in generic positions.
195 double epsilon = 0.0001;
196 double thresh_y = (max_ordinate - min_ordinate) * epsilon;
197 double thresh_x = (max_abscissa - min_abscissa) * epsilon;
198 std::random_device rd;
199 std::default_random_engine re(rd());
200 std::uniform_real_distribution<double> uni(-1, 1);
201 for (int i = 0; i < num_pts_dgm; i++) {
202 double u = uni(re);
203 diagram1[i].first += u * thresh_x;
204 diagram1[i].second += u * thresh_y;
205 diagram2[i].first += u * thresh_x;
206 diagram2[i].second += u * thresh_y;
207 }
208
209 // Compute all angles in both PDs.
210 std::vector<std::pair<double, std::pair<int, int> > > angles1, angles2;
211 for (int i = 0; i < num_pts_dgm; i++) {
212 for (int j = i + 1; j < num_pts_dgm; j++) {
213 double theta1 = compute_angle(diagram1, i, j);
214 double theta2 = compute_angle(diagram2, i, j);
215 angles1.emplace_back(theta1, std::pair<int, int>(i, j));
216 angles2.emplace_back(theta2, std::pair<int, int>(i, j));
217 }
218 }
219
220 // Sort angles.
221 std::sort(angles1.begin(), angles1.end(),
222 [](const std::pair<double, std::pair<int, int> >& p1,
223 const std::pair<double, std::pair<int, int> >& p2) { return (p1.first < p2.first); });
224 std::sort(angles2.begin(), angles2.end(),
225 [](const std::pair<double, std::pair<int, int> >& p1,
226 const std::pair<double, std::pair<int, int> >& p2) { return (p1.first < p2.first); });
227
228 // Initialize orders of the points of both PDs (given by ordinates when theta = -pi/2).
229 std::vector<int> orderp1, orderp2;
230 for (int i = 0; i < num_pts_dgm; i++) {
231 orderp1.push_back(i);
232 orderp2.push_back(i);
233 }
234 std::sort(orderp1.begin(), orderp1.end(), [&](int i, int j) {
235 if (diagram1[i].second != diagram1[j].second)
236 return (diagram1[i].second < diagram1[j].second);
237 else
238 return (diagram1[i].first > diagram1[j].first);
239 });
240 std::sort(orderp2.begin(), orderp2.end(), [&](int i, int j) {
241 if (diagram2[i].second != diagram2[j].second)
242 return (diagram2[i].second < diagram2[j].second);
243 else
244 return (diagram2[i].first > diagram2[j].first);
245 });
246
247 // Find the inverses of the orders.
248 std::vector<int> order1(num_pts_dgm);
249 std::vector<int> order2(num_pts_dgm);
250 for (int i = 0; i < num_pts_dgm; i++) {
251 order1[orderp1[i]] = i;
252 order2[orderp2[i]] = i;
253 }
254
255 // Record all inversions of points in the orders as theta varies along the positive half-disk.
256 std::vector<std::vector<std::pair<int, double> > > anglePerm1(num_pts_dgm);
257 std::vector<std::vector<std::pair<int, double> > > anglePerm2(num_pts_dgm);
258
259 int m1 = angles1.size();
260 for (int i = 0; i < m1; i++) {
261 double theta = angles1[i].first;
262 int p = angles1[i].second.first;
263 int q = angles1[i].second.second;
264 anglePerm1[order1[p]].emplace_back(p, theta);
265 anglePerm1[order1[q]].emplace_back(q, theta);
266 int a = order1[p];
267 int b = order1[q];
268 order1[p] = b;
269 order1[q] = a;
270 }
271
272 int m2 = angles2.size();
273 for (int i = 0; i < m2; i++) {
274 double theta = angles2[i].first;
275 int p = angles2[i].second.first;
276 int q = angles2[i].second.second;
277 anglePerm2[order2[p]].emplace_back(p, theta);
278 anglePerm2[order2[q]].emplace_back(q, theta);
279 int a = order2[p];
280 int b = order2[q];
281 order2[p] = b;
282 order2[q] = a;
283 }
284
285 for (int i = 0; i < num_pts_dgm; i++) {
286 anglePerm1[order1[i]].emplace_back(i, pi / 2);
287 anglePerm2[order2[i]].emplace_back(i, pi / 2);
288 }
289
290 // Compute the SW distance with the list of inversions.
291 for (int i = 0; i < num_pts_dgm; i++) {
292 std::vector<std::pair<int, double> > u, v;
293 u = anglePerm1[i];
294 v = anglePerm2[i];
295 double theta1, theta2;
296 theta1 = -pi / 2;
297 unsigned int ku, kv;
298 ku = 0;
299 kv = 0;
300 theta2 = std::min(u[ku].second, v[kv].second);
301 while (theta1 != pi / 2) {
302 if (diagram1[u[ku].first].first != diagram2[v[kv].first].first ||
303 diagram1[u[ku].first].second != diagram2[v[kv].first].second)
304 if (theta1 != theta2) sw += compute_int(theta1, theta2, u[ku].first, v[kv].first, diagram1, diagram2);
305 theta1 = theta2;
306 if ((theta2 == u[ku].second) && ku < u.size() - 1) ku++;
307 if ((theta2 == v[kv].second) && kv < v.size() - 1) kv++;
308 theta2 = std::min(u[ku].second, v[kv].second);
309 }
310 }
311 } else {
312 double step = pi / this->approx;
313 std::vector<double> v1, v2;
314 for (int i = 0; i < this->approx; i++) {
315 v1.clear();
316 v2.clear();
317 std::merge(this->projections[i].begin(), this->projections[i].end(), second.projections_diagonal[i].begin(),
318 second.projections_diagonal[i].end(), std::back_inserter(v1));
319 std::merge(second.projections[i].begin(), second.projections[i].end(), this->projections_diagonal[i].begin(),
320 this->projections_diagonal[i].end(), std::back_inserter(v2));
321
322 int n = v1.size();
323 double f = 0;
324 for (int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]);
325 sw += f * step;
326 }
327 }
328
329 return sw / pi;
330 }
331
332 public:
344 Sliced_Wasserstein(const Persistence_diagram& _diagram, double _sigma = 1.0, int _approx = 10)
345 : diagram(_diagram), approx(_approx), sigma(_sigma) {
346 build_rep();
347 }
348
356 double compute_scalar_product(const Sliced_Wasserstein& second) const {
357 GUDHI_CHECK(this->sigma == second.sigma,
358 std::invalid_argument("Error: different sigma values for representations"));
359 return std::exp(-compute_sliced_wasserstein_distance(second) / (2 * this->sigma * this->sigma));
360 }
361
369 double distance(const Sliced_Wasserstein& second) const {
370 GUDHI_CHECK(this->sigma == second.sigma,
371 std::invalid_argument("Error: different sigma values for representations"));
372 return std::sqrt(this->compute_scalar_product(*this) + second.compute_scalar_product(second) -
373 2 * this->compute_scalar_product(second));
374 }
375
376}; // class Sliced_Wasserstein
377} // namespace Persistence_representations
378} // namespace Gudhi
379
380#endif // SLICED_WASSERSTEIN_H_
A class implementing the Sliced Wasserstein kernel.
Definition: Sliced_Wasserstein.h:62
Sliced_Wasserstein(const Persistence_diagram &_diagram, double _sigma=1.0, int _approx=10)
Sliced Wasserstein kernel constructor. Implements: Real_valued_topological_data, Topological_data_with_scalar_product, ...
Definition: Sliced_Wasserstein.h:344
double compute_scalar_product(const Sliced_Wasserstein &second) const
Evaluation of the kernel on a pair of diagrams.
Definition: Sliced_Wasserstein.h:356
double distance(const Sliced_Wasserstein &second) const
Evaluation of the distance between images of diagrams in the Hilbert space of the kernel.
Definition: Sliced_Wasserstein.h:369