11#ifndef SLICED_WASSERSTEIN_H_
12#define SLICED_WASSERSTEIN_H_
15#include <gudhi/read_persistence_from_file.h>
16#include <gudhi/common_persistence_representations.h>
17#include <gudhi/Debug_utils.h>
27namespace Persistence_representations {
64 Persistence_diagram diagram;
67 std::vector<std::vector<double> > projections, projections_diagonal;
75 double step = pi / this->approx;
76 int n = diagram.size();
78 for (
int i = 0; i < this->approx; i++) {
79 std::vector<double> l, l_diag;
80 for (
int j = 0; j < n; j++) {
81 double px = diagram[j].first;
82 double py = diagram[j].second;
83 double proj_diag = (px + py) / 2;
85 l.push_back(px * cos(-pi / 2 + i * step) + py * sin(-pi / 2 + i * step));
86 l_diag.push_back(proj_diag * cos(-pi / 2 + i * step) + proj_diag * sin(-pi / 2 + i * step));
89 std::sort(l.begin(), l.end());
90 std::sort(l_diag.begin(), l_diag.end());
91 projections.push_back(std::move(l));
92 projections_diagonal.push_back(std::move(l_diag));
100 double compute_angle(
const Persistence_diagram& diag,
int i,
int j)
const {
101 if (diag[i].second == diag[j].second)
104 return atan((diag[j].first - diag[i].first) / (diag[i].second - diag[j].second));
109 double compute_int_cos(
double alpha,
double beta)
const {
111 if (alpha >= 0 && alpha <= pi) {
112 if (cos(alpha) >= 0) {
113 if (pi / 2 <= beta) {
114 res = 2 - sin(alpha) - sin(beta);
116 res = sin(beta) - sin(alpha);
119 if (1.5 * pi <= beta) {
120 res = 2 + sin(alpha) + sin(beta);
122 res = sin(alpha) - sin(beta);
126 if (alpha >= -pi && alpha <= 0) {
127 if (cos(alpha) <= 0) {
128 if (-pi / 2 <= beta) {
129 res = 2 + sin(alpha) + sin(beta);
131 res = sin(alpha) - sin(beta);
134 if (pi / 2 <= beta) {
135 res = 2 - sin(alpha) - sin(beta);
137 res = sin(beta) - sin(alpha);
144 double compute_int(
double theta1,
double theta2,
int p,
int q,
const Persistence_diagram& diag1,
145 const Persistence_diagram& diag2)
const {
146 double norm = std::sqrt((diag1[p].first - diag2[q].first) * (diag1[p].first - diag2[q].first) +
147 (diag1[p].second - diag2[q].second) * (diag1[p].second - diag2[q].second));
149 if (diag1[p].first == diag2[q].first)
150 angle1 = theta1 - pi / 2;
152 angle1 = theta1 - atan((diag1[p].second - diag2[q].second) / (diag1[p].first - diag2[q].first));
153 double angle2 = angle1 + theta2 - theta1;
154 double integral = compute_int_cos(angle1, angle2);
155 return norm * integral;
160 GUDHI_CHECK(this->approx == second.approx,
161 std::invalid_argument(
"Error: different approx values for representations"));
163 Persistence_diagram diagram1 = this->diagram;
164 Persistence_diagram diagram2 = second.diagram;
167 if (this->approx == -1) {
170 n1 = diagram1.size();
171 n2 = diagram2.size();
172 double min_ordinate = std::numeric_limits<double>::max();
173 double min_abscissa = std::numeric_limits<double>::max();
174 double max_ordinate = std::numeric_limits<double>::lowest();
175 double max_abscissa = std::numeric_limits<double>::lowest();
176 for (
int i = 0; i < n2; i++) {
177 min_ordinate = std::min(min_ordinate, diagram2[i].second);
178 min_abscissa = std::min(min_abscissa, diagram2[i].first);
179 max_ordinate = std::max(max_ordinate, diagram2[i].second);
180 max_abscissa = std::max(max_abscissa, diagram2[i].first);
181 diagram1.emplace_back((diagram2[i].first + diagram2[i].second) / 2,
182 (diagram2[i].first + diagram2[i].second) / 2);
184 for (
int i = 0; i < n1; i++) {
185 min_ordinate = std::min(min_ordinate, diagram1[i].second);
186 min_abscissa = std::min(min_abscissa, diagram1[i].first);
187 max_ordinate = std::max(max_ordinate, diagram1[i].second);
188 max_abscissa = std::max(max_abscissa, diagram1[i].first);
189 diagram2.emplace_back((diagram1[i].first + diagram1[i].second) / 2,
190 (diagram1[i].first + diagram1[i].second) / 2);
192 int num_pts_dgm = diagram1.size();
195 double epsilon = 0.0001;
196 double thresh_y = (max_ordinate - min_ordinate) * epsilon;
197 double thresh_x = (max_abscissa - min_abscissa) * epsilon;
198 std::random_device rd;
199 std::default_random_engine re(rd());
200 std::uniform_real_distribution<double> uni(-1, 1);
201 for (
int i = 0; i < num_pts_dgm; i++) {
203 diagram1[i].first += u * thresh_x;
204 diagram1[i].second += u * thresh_y;
205 diagram2[i].first += u * thresh_x;
206 diagram2[i].second += u * thresh_y;
210 std::vector<std::pair<double, std::pair<int, int> > > angles1, angles2;
211 for (
int i = 0; i < num_pts_dgm; i++) {
212 for (
int j = i + 1; j < num_pts_dgm; j++) {
213 double theta1 = compute_angle(diagram1, i, j);
214 double theta2 = compute_angle(diagram2, i, j);
215 angles1.emplace_back(theta1, std::pair<int, int>(i, j));
216 angles2.emplace_back(theta2, std::pair<int, int>(i, j));
221 std::sort(angles1.begin(), angles1.end(),
222 [](
const std::pair<
double, std::pair<int, int> >& p1,
223 const std::pair<
double, std::pair<int, int> >& p2) { return (p1.first < p2.first); });
224 std::sort(angles2.begin(), angles2.end(),
225 [](
const std::pair<
double, std::pair<int, int> >& p1,
226 const std::pair<
double, std::pair<int, int> >& p2) { return (p1.first < p2.first); });
229 std::vector<int> orderp1, orderp2;
230 for (
int i = 0; i < num_pts_dgm; i++) {
231 orderp1.push_back(i);
232 orderp2.push_back(i);
234 std::sort(orderp1.begin(), orderp1.end(), [&](
int i,
int j) {
235 if (diagram1[i].second != diagram1[j].second)
236 return (diagram1[i].second < diagram1[j].second);
238 return (diagram1[i].first > diagram1[j].first);
240 std::sort(orderp2.begin(), orderp2.end(), [&](
int i,
int j) {
241 if (diagram2[i].second != diagram2[j].second)
242 return (diagram2[i].second < diagram2[j].second);
244 return (diagram2[i].first > diagram2[j].first);
248 std::vector<int> order1(num_pts_dgm);
249 std::vector<int> order2(num_pts_dgm);
250 for (
int i = 0; i < num_pts_dgm; i++) {
251 order1[orderp1[i]] = i;
252 order2[orderp2[i]] = i;
256 std::vector<std::vector<std::pair<int, double> > > anglePerm1(num_pts_dgm);
257 std::vector<std::vector<std::pair<int, double> > > anglePerm2(num_pts_dgm);
259 int m1 = angles1.size();
260 for (
int i = 0; i < m1; i++) {
261 double theta = angles1[i].first;
262 int p = angles1[i].second.first;
263 int q = angles1[i].second.second;
264 anglePerm1[order1[p]].emplace_back(p, theta);
265 anglePerm1[order1[q]].emplace_back(q, theta);
272 int m2 = angles2.size();
273 for (
int i = 0; i < m2; i++) {
274 double theta = angles2[i].first;
275 int p = angles2[i].second.first;
276 int q = angles2[i].second.second;
277 anglePerm2[order2[p]].emplace_back(p, theta);
278 anglePerm2[order2[q]].emplace_back(q, theta);
285 for (
int i = 0; i < num_pts_dgm; i++) {
286 anglePerm1[order1[i]].emplace_back(i, pi / 2);
287 anglePerm2[order2[i]].emplace_back(i, pi / 2);
291 for (
int i = 0; i < num_pts_dgm; i++) {
292 std::vector<std::pair<int, double> > u, v;
295 double theta1, theta2;
300 theta2 = std::min(u[ku].second, v[kv].second);
301 while (theta1 != pi / 2) {
302 if (diagram1[u[ku].first].first != diagram2[v[kv].first].first ||
303 diagram1[u[ku].first].second != diagram2[v[kv].first].second)
304 if (theta1 != theta2) sw += compute_int(theta1, theta2, u[ku].first, v[kv].first, diagram1, diagram2);
306 if ((theta2 == u[ku].second) && ku < u.size() - 1) ku++;
307 if ((theta2 == v[kv].second) && kv < v.size() - 1) kv++;
308 theta2 = std::min(u[ku].second, v[kv].second);
312 double step = pi / this->approx;
313 std::vector<double> v1, v2;
314 for (
int i = 0; i < this->approx; i++) {
317 std::merge(this->projections[i].begin(), this->projections[i].end(), second.projections_diagonal[i].begin(),
318 second.projections_diagonal[i].end(), std::back_inserter(v1));
319 std::merge(second.projections[i].begin(), second.projections[i].end(), this->projections_diagonal[i].begin(),
320 this->projections_diagonal[i].end(), std::back_inserter(v2));
324 for (
int j = 0; j < n; j++) f += std::abs(v1[j] - v2[j]);
345 : diagram(_diagram), approx(_approx), sigma(_sigma) {
357 GUDHI_CHECK(this->sigma == second.sigma,
358 std::invalid_argument(
"Error: different sigma values for representations"));
359 return std::exp(-compute_sliced_wasserstein_distance(second) / (2 * this->sigma * this->sigma));
370 GUDHI_CHECK(this->sigma == second.sigma,
371 std::invalid_argument(
"Error: different sigma values for representations"));
373 2 * this->compute_scalar_product(second));
A class implementing the Sliced Wasserstein kernel.
Definition: Sliced_Wasserstein.h:62
Sliced_Wasserstein(const Persistence_diagram &_diagram, double _sigma=1.0, int _approx=10)
Sliced Wasserstein kernel constructor. Implements Topological_data_with_distances, Real_valued_topological_data, Topological_data_with_scalar_product.
Definition: Sliced_Wasserstein.h:344
double compute_scalar_product(const Sliced_Wasserstein &second) const
Evaluation of the kernel on a pair of diagrams.
Definition: Sliced_Wasserstein.h:356
double distance(const Sliced_Wasserstein &second) const
Evaluation of the distance between images of diagrams in the Hilbert space of the kernel.
Definition: Sliced_Wasserstein.h:369