Loading...
Searching...
No Matches
Sliced_Wasserstein.h
1/* This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
2 * See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
3 * Author(s): Mathieu Carriere
4 *
5 * Copyright (C) 2018 Inria
6 *
7 * Modification(s):
8 * - YYYY/MM Author: Description of the modification
9 */
10
11#ifndef SLICED_WASSERSTEIN_H_
12#define SLICED_WASSERSTEIN_H_
13
14// standard include
15#include <vector> // for std::vector<>
16#include <utility> // for std::pair<>, std::move
17#include <algorithm> // for std::sort, std::max, std::merge
18#include <cmath> // for std::abs, std::sqrt
19#include <stdexcept> // for std::invalid_argument
20#include <random> // for std::random_device
21
22// gudhi include
23#include <gudhi/read_persistence_from_file.h>
25#include <gudhi/Debug_utils.h>
26
27namespace Gudhi {
28namespace Persistence_representations {
29
65{
66 public:
76 Sliced_Wasserstein(const Persistence_diagram& diagram, double sigma = 1.0, int approx = 10)
77 : diagram_(diagram), approx_(approx), sigma_(sigma)
78 {
79 _build_rep();
80 }
81
88 double compute_scalar_product(const Sliced_Wasserstein& second) const
89 {
90 GUDHI_CHECK(this->sigma_ == second.sigma_,
91 std::invalid_argument("Error: different sigma values for representations"));
92 return std::exp(-_compute_sliced_wasserstein_distance(second) / (2 * this->sigma_ * this->sigma_));
93 }
94
101 double distance(const Sliced_Wasserstein& second) const
102 {
103 GUDHI_CHECK(this->sigma_ == second.sigma_,
104 std::invalid_argument("Error: different sigma values for representations"));
105 return std::sqrt(this->compute_scalar_product(*this) + second.compute_scalar_product(second) -
106 2 * this->compute_scalar_product(second));
107 }
108
109 private:
110 Persistence_diagram diagram_;
111 int approx_;
112 double sigma_;
113 std::vector<std::vector<double> > projections_, projections_diagonal_;
114
115 // **********************************
116 // Utils.
117 // **********************************
118
119 void _build_rep()
120 {
121 if (approx_ > 0) {
122 double step = pi / this->approx_;
123 int n = diagram_.size();
124
125 for (int i = 0; i < this->approx_; i++) {
126 std::vector<double> l, l_diag;
127 for (int j = 0; j < n; j++) {
128 double px = diagram_[j].first;
129 double py = diagram_[j].second;
130 double proj_diag = (px + py) / 2;
131
132 l.push_back(px * cos(-pi / 2 + i * step) + py * sin(-pi / 2 + i * step));
133 l_diag.push_back(proj_diag * cos(-pi / 2 + i * step) + proj_diag * sin(-pi / 2 + i * step));
134 }
135
136 std::sort(l.begin(), l.end());
137 std::sort(l_diag.begin(), l_diag.end());
138 projections_.push_back(std::move(l));
139 projections_diagonal_.push_back(std::move(l_diag));
140 }
141
142 diagram_.clear();
143 }
144 }
145
146 // Compute the angle formed by two points of a PD
147 double _compute_angle(const Persistence_diagram& diag, int i, int j) const
148 {
149 if (diag[i].second == diag[j].second)
150 return pi / 2;
151 else
152 return atan((diag[j].first - diag[i].first) / (diag[i].second - diag[j].second));
153 }
154
155 // Compute the integral of |cos()| between alpha and beta, valid only if alpha is in [-pi,pi] and beta-alpha is in
156 // [0,pi]
157 double _compute_int_cos(double alpha, double beta) const
158 {
159 double res = 0;
160 if (alpha >= 0 && alpha <= pi) {
161 if (cos(alpha) >= 0) {
162 if (pi / 2 <= beta) {
163 res = 2 - sin(alpha) - sin(beta);
164 } else {
165 res = sin(beta) - sin(alpha);
166 }
167 } else {
168 if (1.5 * pi <= beta) {
169 res = 2 + sin(alpha) + sin(beta);
170 } else {
171 res = sin(alpha) - sin(beta);
172 }
173 }
174 }
175 if (alpha >= -pi && alpha <= 0) {
176 if (cos(alpha) <= 0) {
177 if (-pi / 2 <= beta) {
178 res = 2 + sin(alpha) + sin(beta);
179 } else {
180 res = sin(alpha) - sin(beta);
181 }
182 } else {
183 if (pi / 2 <= beta) {
184 res = 2 - sin(alpha) - sin(beta);
185 } else {
186 res = sin(beta) - sin(alpha);
187 }
188 }
189 }
190 return res;
191 }
192
193 double _compute_int(double theta1,
194 double theta2,
195 int p,
196 int q,
197 const Persistence_diagram& diag1,
198 const Persistence_diagram& diag2) const
199 {
200 double norm = std::sqrt((diag1[p].first - diag2[q].first) * (diag1[p].first - diag2[q].first) +
201 (diag1[p].second - diag2[q].second) * (diag1[p].second - diag2[q].second));
202 double angle1;
203 if (diag1[p].first == diag2[q].first)
204 angle1 = theta1 - pi / 2;
205 else
206 angle1 = theta1 - atan((diag1[p].second - diag2[q].second) / (diag1[p].first - diag2[q].first));
207 double angle2 = angle1 + theta2 - theta1;
208 double integral = _compute_int_cos(angle1, angle2);
209 return norm * integral;
210 }
211
212 // Evaluation of the Sliced Wasserstein Distance between a pair of diagrams.
213 // TODO: decompose it in smaller methods if some modifications have to be done one day?
214 double _compute_sliced_wasserstein_distance(const Sliced_Wasserstein& second) const
215 {
216 GUDHI_CHECK(this->approx_ == second.approx_,
217 std::invalid_argument("Error: different approx values for representations"));
218
219 Persistence_diagram diagram1 = this->diagram_;
220 Persistence_diagram diagram2 = second.diagram_;
221 double sw = 0;
222
223 if (this->approx_ == -1) {
224 // Add projections onto diagonal.
225 int n1, n2;
226 n1 = diagram1.size();
227 n2 = diagram2.size();
228 double min_ordinate = std::numeric_limits<double>::max();
229 double min_abscissa = std::numeric_limits<double>::max();
230 double max_ordinate = std::numeric_limits<double>::lowest();
231 double max_abscissa = std::numeric_limits<double>::lowest();
232 for (int i = 0; i < n2; i++) {
233 min_ordinate = std::min(min_ordinate, diagram2[i].second);
234 min_abscissa = std::min(min_abscissa, diagram2[i].first);
235 max_ordinate = std::max(max_ordinate, diagram2[i].second);
236 max_abscissa = std::max(max_abscissa, diagram2[i].first);
237 diagram1.emplace_back((diagram2[i].first + diagram2[i].second) / 2,
238 (diagram2[i].first + diagram2[i].second) / 2);
239 }
240 for (int i = 0; i < n1; i++) {
241 min_ordinate = std::min(min_ordinate, diagram1[i].second);
242 min_abscissa = std::min(min_abscissa, diagram1[i].first);
243 max_ordinate = std::max(max_ordinate, diagram1[i].second);
244 max_abscissa = std::max(max_abscissa, diagram1[i].first);
245 diagram2.emplace_back((diagram1[i].first + diagram1[i].second) / 2,
246 (diagram1[i].first + diagram1[i].second) / 2);
247 }
248 int num_pts_dgm = diagram1.size();
249
250 // Slightly perturb the points so that the PDs are in generic positions.
251 double epsilon = 0.0001;
252 double thresh_y = (max_ordinate - min_ordinate) * epsilon;
253 double thresh_x = (max_abscissa - min_abscissa) * epsilon;
254 std::random_device rd;
255 std::default_random_engine re(rd());
256 std::uniform_real_distribution<double> uni(-1, 1);
257 for (int i = 0; i < num_pts_dgm; i++) {
258 double u = uni(re);
259 diagram1[i].first += u * thresh_x;
260 diagram1[i].second += u * thresh_y;
261 diagram2[i].first += u * thresh_x;
262 diagram2[i].second += u * thresh_y;
263 }
264
265 // Compute all angles in both PDs.
266 std::vector<std::pair<double, std::pair<int, int> > > angles1, angles2;
267 for (int i = 0; i < num_pts_dgm; i++) {
268 for (int j = i + 1; j < num_pts_dgm; j++) {
269 double theta1 = _compute_angle(diagram1, i, j);
270 double theta2 = _compute_angle(diagram2, i, j);
271 angles1.emplace_back(theta1, std::pair<int, int>(i, j));
272 angles2.emplace_back(theta2, std::pair<int, int>(i, j));
273 }
274 }
275
276 // Sort angles.
277 std::sort(angles1.begin(),
278 angles1.end(),
279 [](const std::pair<double, std::pair<int, int> >& p1,
280 const std::pair<double, std::pair<int, int> >& p2) { return (p1.first < p2.first); });
281 std::sort(angles2.begin(),
282 angles2.end(),
283 [](const std::pair<double, std::pair<int, int> >& p1,
284 const std::pair<double, std::pair<int, int> >& p2) { return (p1.first < p2.first); });
285
286 // Initialize orders of the points of both PDs (given by ordinates when theta = -pi/2).
287 std::vector<int> orderp1, orderp2;
288 for (int i = 0; i < num_pts_dgm; i++) {
289 orderp1.push_back(i);
290 orderp2.push_back(i);
291 }
292 std::sort(orderp1.begin(), orderp1.end(), [&](int i, int j) {
293 if (diagram1[i].second != diagram1[j].second)
294 return (diagram1[i].second < diagram1[j].second);
295 else
296 return (diagram1[i].first > diagram1[j].first);
297 });
298 std::sort(orderp2.begin(), orderp2.end(), [&](int i, int j) {
299 if (diagram2[i].second != diagram2[j].second)
300 return (diagram2[i].second < diagram2[j].second);
301 else
302 return (diagram2[i].first > diagram2[j].first);
303 });
304
305 // Find the inverses of the orders.
306 std::vector<int> order1(num_pts_dgm);
307 std::vector<int> order2(num_pts_dgm);
308 for (int i = 0; i < num_pts_dgm; i++) {
309 order1[orderp1[i]] = i;
310 order2[orderp2[i]] = i;
311 }
312
313 // Record all inversions of points in the orders as theta varies along the positive half-disk.
314 std::vector<std::vector<std::pair<int, double> > > anglePerm1(num_pts_dgm);
315 std::vector<std::vector<std::pair<int, double> > > anglePerm2(num_pts_dgm);
316
317 int m1 = angles1.size();
318 for (int i = 0; i < m1; i++) {
319 double theta = angles1[i].first;
320 int p = angles1[i].second.first;
321 int q = angles1[i].second.second;
322 anglePerm1[order1[p]].emplace_back(p, theta);
323 anglePerm1[order1[q]].emplace_back(q, theta);
324 int a = order1[p];
325 int b = order1[q];
326 order1[p] = b;
327 order1[q] = a;
328 }
329
330 int m2 = angles2.size();
331 for (int i = 0; i < m2; i++) {
332 double theta = angles2[i].first;
333 int p = angles2[i].second.first;
334 int q = angles2[i].second.second;
335 anglePerm2[order2[p]].emplace_back(p, theta);
336 anglePerm2[order2[q]].emplace_back(q, theta);
337 int a = order2[p];
338 int b = order2[q];
339 order2[p] = b;
340 order2[q] = a;
341 }
342
343 for (int i = 0; i < num_pts_dgm; i++) {
344 anglePerm1[order1[i]].emplace_back(i, pi / 2);
345 anglePerm2[order2[i]].emplace_back(i, pi / 2);
346 }
347
348 // Compute the SW distance with the list of inversions.
349 for (int i = 0; i < num_pts_dgm; i++) {
350 std::vector<std::pair<int, double> > u, v;
351 u = anglePerm1[i];
352 v = anglePerm2[i];
353 double theta1, theta2;
354 theta1 = -pi / 2;
355 unsigned int ku, kv;
356 ku = 0;
357 kv = 0;
358 theta2 = std::min(u[ku].second, v[kv].second);
359 while (theta1 != pi / 2) {
360 if (diagram1[u[ku].first].first != diagram2[v[kv].first].first ||
361 diagram1[u[ku].first].second != diagram2[v[kv].first].second)
362 if (theta1 != theta2) sw += _compute_int(theta1, theta2, u[ku].first, v[kv].first, diagram1, diagram2);
363 theta1 = theta2;
364 if ((theta2 == u[ku].second) && ku < u.size() - 1) ku++;
365 if ((theta2 == v[kv].second) && kv < v.size() - 1) kv++;
366 theta2 = std::min(u[ku].second, v[kv].second);
367 }
368 }
369 } else {
370 double step = pi / this->approx_;
371 std::vector<double> v1, v2;
372 for (int i = 0; i < this->approx_; i++) {
373 v1.clear();
374 v2.clear();
375 std::merge(this->projections_[i].begin(),
376 this->projections_[i].end(),
377 second.projections_diagonal_[i].begin(),
378 second.projections_diagonal_[i].end(),
379 std::back_inserter(v1));
380 std::merge(second.projections_[i].begin(),
381 second.projections_[i].end(),
382 this->projections_diagonal_[i].begin(),
383 this->projections_diagonal_[i].end(),
384 std::back_inserter(v2));
385
386 int n = v1.size();
387 double f = 0;
388 for (int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]);
389 sw += f * step;
390 }
391 }
392
393 return sw / pi;
394 }
395
396}; // class Sliced_Wasserstein
397
398} // namespace Persistence_representations
399} // namespace Gudhi
400
401#endif // SLICED_WASSERSTEIN_H_
Sliced_Wasserstein(const Persistence_diagram &diagram, double sigma=1.0, int approx=10)
Sliced Wasserstein kernel constructor.
Definition Sliced_Wasserstein.h:76
double compute_scalar_product(const Sliced_Wasserstein &second) const
Evaluation of the kernel on a pair of diagrams.
Definition Sliced_Wasserstein.h:88
double distance(const Sliced_Wasserstein &second) const
Evaluation of the distance between images of diagrams in the Hilbert space of the kernel.
Definition Sliced_Wasserstein.h:101
This file contains an implementation of some common procedures used in the Persistence_representations...
constexpr double pi
Definition common_persistence_representations.h:45
std::vector< std::pair< double, double > > Persistence_diagram
Definition common_persistence_representations.h:38
Gudhi namespace.
Definition SimplicialComplexForAlpha.h:14