Loading...
Searching...
No Matches
simple_mdspan.h
Go to the documentation of this file.
1/* This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
2 * See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
3 * Author(s): Hannah Schreiber, David Loiseaux
4 *
5 * Copyright (C) 2025 Inria
6 *
7 * Modification(s):
8 * - YYYY/MM Author: Description of the modification
9 */
10
16
17#ifndef GUDHI_SIMPLE_MDSPAN_H_
18#define GUDHI_SIMPLE_MDSPAN_H_
19
#include <array>
#include <cstddef>  // std::size_t
#include <initializer_list>
#include <limits>
#include <ostream>  // std::ostream, used by operator<< of extents
#include <stdexcept>
#include <type_traits>  // std::remove_cv_t, std::make_unsigned_t, std::integral_constant
#include <utility>

#include <gudhi/Debug_utils.h>
29
30namespace Gudhi {
31
/**
 * @brief Value marking an extent which is only known at run time (equivalent of `std::dynamic_extent`).
 */
inline constexpr std::size_t dynamic_extent = std::numeric_limits<std::size_t>::max();

template <class IndexType, std::size_t... Extents>
class extents;

namespace detail {

// Evaluates to 1 if `v` is `dynamic_extent`, 0 otherwise.
template <std::size_t v>
struct is_dynamic : std::integral_constant<std::size_t, 0> {};

template <>
struct is_dynamic<dynamic_extent> : std::integral_constant<std::size_t, 1> {};

// Number of `dynamic_extent` values among the first `I + 1` values of the `std::integer_sequence` `T`.
// If `I` goes past the end of the sequence, the whole sequence is counted.
template <std::size_t I, class T>
struct dynamic_count;

// Empty sequence: nothing to count, whatever `I` is. This is the rank-0 case, where a caller passing `rank() - 1`
// gets the maximal std::size_t; without this specialization `dynamic_count` would be an incomplete type there.
template <std::size_t I>
struct dynamic_count<I, std::integer_sequence<std::size_t> > : std::integral_constant<std::size_t, 0> {};

template <std::size_t I, std::size_t first, std::size_t... Tail>
struct dynamic_count<I, std::integer_sequence<std::size_t, first, Tail...> >
    : std::integral_constant<std::size_t,
                             (is_dynamic<first>::value +
                              dynamic_count<I - 1, std::integer_sequence<std::size_t, Tail...> >::value)> {};

template <std::size_t first, std::size_t... Tail>
struct dynamic_count<0, std::integer_sequence<std::size_t, first, Tail...> >
    : std::integral_constant<std::size_t, is_dynamic<first>::value> {};

template <std::size_t first>
struct dynamic_count<0, std::integer_sequence<std::size_t, first> >
    : std::integral_constant<std::size_t, is_dynamic<first>::value> {};

// Value at position `I` of the `std::integer_sequence` `T`.
template <std::size_t I, class T>
struct extent_value;

template <std::size_t I, std::size_t first, std::size_t... Tail>
struct extent_value<I, std::integer_sequence<std::size_t, first, Tail...> >
    : extent_value<I - 1, std::integer_sequence<std::size_t, Tail...> > {};

template <std::size_t first, std::size_t... Tail>
struct extent_value<0, std::integer_sequence<std::size_t, first, Tail...> >
    : std::integral_constant<std::size_t, first> {};

// `type` is `std::integer_sequence<T, integers...>` followed by `N - I` copies of `dynamic_extent`.
// Starting from `I = 0` and no `integers`, this gives `N` times `dynamic_extent`.
template <class T, T I, T N, T... integers>
struct dynamic_value_sequence {
  using type = typename dynamic_value_sequence<T, I + 1, N, integers..., dynamic_extent>::type;
};

template <class T, T N, T... integers>
struct dynamic_value_sequence<T, N, N, integers...> {
  using type = std::integer_sequence<T, integers...>;
};

// Returns a default-constructed `extents<IndexType, Pack...>`, deducing the extent values from the sequence.
template <class IndexType, std::size_t... Pack>
constexpr auto dynamic_value_extents(std::integer_sequence<std::size_t, Pack...>)
{
  return extents<IndexType, Pack...>();
}

// Default-constructed extents of rank `Rank` where every extent is dynamic.
template <class IndexType, std::size_t Rank>
constexpr auto dynamic_value_extents_value =
    dynamic_value_extents<IndexType>(typename dynamic_value_sequence<std::size_t, 0, Rank>::type{});

}  // namespace detail

/**
 * @brief Extents of rank `Rank` where every extent is dynamic (equivalent of `std::dextents`).
 *
 * `decltype` of a `constexpr` variable yields a const-qualified type, so the qualifier is removed here. Otherwise
 * the extents type, and any mapping or `Simple_mdspan` storing it, would not be assignable and its dynamic extents
 * could not be updated.
 */
template <class IndexType, std::size_t Rank>
using dextents = std::remove_cv_t<decltype(detail::dynamic_value_extents_value<IndexType, Rank>)>;
99
104template <class IndexType, std::size_t... Extents>
105class extents
106{
107 public:
108 using index_type = IndexType;
109 using size_type = std::make_unsigned_t<index_type>;
110 using rank_type = std::size_t;
111
112 // observers of the multidimensional index space
113 static constexpr rank_type rank() noexcept { return sizeof...(Extents); }
114
115 static constexpr rank_type rank_dynamic() noexcept
116 {
117 return detail::dynamic_count<rank() - 1, std::integer_sequence<std::size_t, Extents...> >::value;
118 }
119
120 static constexpr std::size_t static_extent(rank_type r) noexcept
121 {
122 std::array<std::size_t, sizeof...(Extents)> exts{Extents...};
123 return exts[r];
124 }
125
126 constexpr index_type extent(rank_type r) const noexcept
127 {
128 if (dynamic_extent_shifts_[r] < 0) return static_extent(r);
129 return dynamic_extents_[dynamic_extent_shifts_[r]];
130 }
131
132 void update_dynamic_extent(rank_type r, index_type i){
133 if (dynamic_extent_shifts_[r] < 0) throw std::invalid_argument("Given rank is not dynamic.");
134 dynamic_extents_[dynamic_extent_shifts_[r]] = i;
135 }
136
137 // constructors
138 constexpr extents() noexcept : dynamic_extents_(), dynamic_extent_shifts_(_init_shifts()) {}
139
140 template <class OtherIndexType, std::size_t... OtherExtents>
141 constexpr explicit extents(const extents<OtherIndexType, OtherExtents...>& other) noexcept
142 : dynamic_extents_(), dynamic_extent_shifts_(_init_shifts())
143 {
144 for (rank_type r = 0; r < rank(); ++r) {
145 if (dynamic_extent_shifts_[r] >= 0) dynamic_extents_[dynamic_extent_shifts_[r]] = other.extent(r);
146 }
147 }
148
149 template <class... OtherIndexTypes>
150 constexpr explicit extents(OtherIndexTypes... extents) noexcept
151 : dynamic_extents_{static_cast<IndexType>(extents)...}, dynamic_extent_shifts_(_init_shifts())
152 {}
153
154 template <class OtherIndexType, std::size_t N>
155 constexpr explicit extents(const std::array<OtherIndexType, N>& other) noexcept
156 : dynamic_extents_{other}, dynamic_extent_shifts_(_init_shifts())
157 {}
158
159 // comparison operators
160 template <class OtherIndexType, std::size_t... OtherExtents>
161 friend constexpr bool operator==(const extents& e1, const extents<OtherIndexType, OtherExtents...>& e2) noexcept
162 {
163 if (e1.rank() != e2.rank()) return false;
164 for (rank_type r = 0; r < rank(); ++r) {
165 if (e1.extent(r) != e2.extent(r)) return false;
166 }
167 return true;
168 }
169
170 friend void swap(extents& e1, extents& e2) noexcept
171 {
172 e1.dynamic_extents_.swap(e2.dynamic_extents_);
173 e1.dynamic_extent_shifts_.swap(e2.dynamic_extent_shifts_);
174 }
175
176 friend std::ostream &operator<<(std::ostream &stream, const extents &e)
177 {
178 stream << "[ " << sizeof...(Extents) << " ] ";
179 ((stream << Extents << ' '), ...);
180 stream << " [";
181 for (rank_type r = 0; r < e.rank(); ++r) stream << e.extent(r) << " ";
182 stream << "]";
183
184 return stream;
185 }
186
187 private:
188 std::array<index_type, rank_dynamic()> dynamic_extents_;
189 std::array<int, rank()> dynamic_extent_shifts_;
190
191 static constexpr std::array<int, rank()> _init_shifts()
192 {
193 std::array<std::size_t, sizeof...(Extents)> exts{Extents...};
194 std::array<int, rank()> res = {};
195 std::size_t index = 0;
196 for (rank_type i = 0; i < rank(); ++i) {
197 if (exts[i] == dynamic_extent) {
198 res[i] = index;
199 ++index;
200 } else {
201 res[i] = -1;
202 }
203 }
204 return res;
205 }
206};
207
// This deduction guide does not seem to compile in C++17(?), because the use of 'dextents' is not explicit enough:
// "trailing return type ‘Gudhi::dextents<long unsigned int, sizeof... (Integrals)>’ of deduction guide is not a
// specialization of ‘Gudhi::extents<IndexType, Extents>’"
// Does anyone know a workaround? A possible one (untested here) is to spell out the 'extents' specialization
// directly instead of going through the 'dextents' alias, e.g.:
// -> extents<std::size_t, ((void)sizeof(Integrals), dynamic_extent)...>;
// template<class... Integrals>
// explicit extents(Integrals...) -> dextents<std::size_t, sizeof...(Integrals)>;
214
219class layout_right
220{
221 public:
222 template<class Extents>
223 class mapping
224 {
225 public:
226 using extents_type = Extents;
227 using index_type = typename extents_type::index_type;
228 using size_type = typename extents_type::size_type;
229 using rank_type = typename extents_type::rank_type;
230 using layout_type = layout_right;
231
232 // constructors
233 mapping() noexcept = default;
234 mapping(const mapping&) noexcept = default;
235
236 mapping(const extents_type& exts) noexcept : exts_(exts)
237 {
238 if constexpr (extents_type::rank() != 0) _initialize_strides();
239 }
240
241 mapping& operator=(const mapping&) noexcept = default;
242
243 // observers
244 constexpr const extents_type& extents() const noexcept { return exts_; }
245
246 index_type required_span_size() const noexcept
247 {
248 if constexpr (extents_type::rank() == 0) return 0;
249 else return ext_shifts_[0] * exts_.extent(0);
250 }
251
252 template <class... Indices>
253 constexpr index_type operator()(Indices... indices) const
254 {
255 return operator()({static_cast<index_type>(indices)...});
256 }
257
258 template <class IndexRange = std::initializer_list<index_type> >
259 constexpr index_type operator()(const IndexRange& indices) const
260 {
261 GUDHI_CHECK(indices.size() == extents_type::rank(), "Wrong number of parameters.");
262
263 index_type newIndex = 0;
264 auto it = indices.begin();
265 GUDHI_CHECK_code(unsigned int i = 0);
266 for (auto stride : ext_shifts_) {
267 GUDHI_CHECK_code(GUDHI_CHECK(*it < exts_.extent(i), "Out of bound index."));
268 newIndex += (stride * (*it));
269 ++it;
270 GUDHI_CHECK_code(++i);
271 }
272
273 return newIndex;
274 }
275
276 static constexpr bool is_always_unique() noexcept { return true; }
277
278 static constexpr bool is_always_exhaustive() noexcept { return true; }
279
280 static constexpr bool is_always_strided() noexcept { return true; }
281
282 static constexpr bool is_unique() noexcept { return true; }
283
284 static constexpr bool is_exhaustive() noexcept { return true; }
285
286 static constexpr bool is_strided() noexcept { return true; }
287
288 index_type stride(rank_type r) const
289 {
290 GUDHI_CHECK(r < ext_shifts_.size(), "Stride out of bound.");
291 return ext_shifts_[r];
292 }
293
294 friend bool operator==(const mapping& m1, const mapping& m2) noexcept { return m1.exts_ == m2.exts_; }
295
296 friend void swap(mapping& m1, mapping& m2) noexcept
297 {
298 swap(m1.exts_, m2.exts_);
299 m1.ext_shifts_.swap(m2.ext_shifts_);
300 }
301
302 // update can be faster than reconstructing everytime if only relatively small r's are updated.
303 void update_extent(rank_type r, index_type new_value)
304 {
305 GUDHI_CHECK(r < extents_type::rank(), "Index out of bound.");
306 exts_.update_dynamic_extent(r, new_value);
307 _update_strides(r);
308 }
309
310 private:
311 extents_type exts_;
312 std::array<index_type,extents_type::rank()> ext_shifts_;
313
314 constexpr void _initialize_strides()
315 {
316 ext_shifts_[extents_type::rank() - 1] = 1;
317 for (auto i = extents_type::rank() - 1; i > 0; --i) {
318 ext_shifts_[i - 1] = ext_shifts_[i] * exts_.extent(i);
319 }
320 }
321
322 constexpr void _update_strides(rank_type start)
323 {
324 for (auto i = start; i > 0; --i) {
325 ext_shifts_[i - 1] = ext_shifts_[i] * exts_.extent(i);
326 }
327 }
328 };
329};
330
345template <typename T, class Extents, class LayoutPolicy = layout_right>
346class Simple_mdspan
347{
348 public:
349 using layout_type = LayoutPolicy;
350 using mapping_type = typename LayoutPolicy::template mapping<Extents>;
351 using extents_type = Extents;
352 using element_type = T;
353 using value_type = std::remove_cv_t<T>;
354 using index_type = typename mapping_type::index_type;
355 using size_type = typename mapping_type::size_type;
356 using rank_type = typename mapping_type::rank_type;
357 using data_handle_type = T*;
358 using reference = T&;
359
360 Simple_mdspan() : ptr_(nullptr) {}
361
362 Simple_mdspan(const Simple_mdspan& rhs) = default;
363 Simple_mdspan(Simple_mdspan&& rhs) = default;
364
365 template <class... IndexTypes>
366 explicit Simple_mdspan(data_handle_type ptr, IndexTypes... exts)
367 : ptr_(ptr), map_(extents_type(exts...))
368 {
369 GUDHI_CHECK(ptr != nullptr || empty() || Extents::rank() == 0, "Given pointer is not properly initialized.");
370 }
371
372 template <class OtherIndexType, size_t N>
373 constexpr explicit Simple_mdspan(data_handle_type ptr, const std::array<OtherIndexType, N>& exts)
374 : ptr_(ptr), map_(extents_type(exts))
375 {
376 GUDHI_CHECK(ptr != nullptr || empty() || Extents::rank() == 0, "Given pointer is not properly initialized.");
377 }
378
379 Simple_mdspan(data_handle_type ptr, const mapping_type& m) : ptr_(ptr), map_(m) {}
380
381 Simple_mdspan& operator=(const Simple_mdspan& rhs) = default;
382 Simple_mdspan& operator=(Simple_mdspan&& rhs) = default;
383
384 // version with [] not possible before C++23
385 template <class... IndexTypes>
386 constexpr reference operator()(IndexTypes... indices) const
387 {
388 return operator[]({static_cast<index_type>(indices)...});
389 }
390
391 template <class IndexRange = std::initializer_list<index_type> >
392 reference operator[](const IndexRange& indices) const
393 {
394 return *(ptr_ + map_(indices));
395 }
396
397 constexpr rank_type rank() noexcept { return map_.extents().rank(); }
398
399 constexpr rank_type rank_dynamic() noexcept { return map_.extents().rank_dynamic(); }
400
401 static constexpr std::size_t static_extent(rank_type r) noexcept { return extents_type::static_extent(r); }
402
403 constexpr index_type extent(rank_type r) const
404 {
405 GUDHI_CHECK(r < map_.extents().rank(), "Out of bound index.");
406 return map_.extents().extent(r);
407 }
408
409 constexpr size_type size() const noexcept { return map_.required_span_size(); }
410
411 constexpr bool empty() const noexcept { return map_.required_span_size() == 0; }
412
413 constexpr index_type stride(rank_type r) const { return map_.stride(r); }
414
415 constexpr const extents_type& extents() const noexcept { return map_.extents(); }
416
417 constexpr const data_handle_type& data_handle() const noexcept { return ptr_; }
418
419 constexpr const mapping_type& mapping() const noexcept { return map_; }
420
421 // if is_unique() is true for all possible instantiations of this class
422 static constexpr bool is_always_unique() { return mapping_type::is_always_unique(); }
423
424 // if is_exhaustive() is true for all possible instantiations of this class
425 static constexpr bool is_always_exhaustive() { return mapping_type::is_always_exhaustive(); }
426
427 // if is_strided() is true for all possible instantiations of this class
428 static constexpr bool is_always_strided() { return mapping_type::is_always_strided(); }
429
430 // unicity of the mapping (i,j,k,...) -> real index
431 constexpr bool is_unique() const { return map_.is_unique(); }
432
433 // if all real indices have a preimage in form (i,j,k,...)
434 constexpr bool is_exhaustive() const { return map_.is_exhaustive(); }
435
436 // if distance in memory is constant between two values in same rank
437 constexpr bool is_strided() const { return map_.is_strided(); }
438
439 friend constexpr void swap(Simple_mdspan& x, Simple_mdspan& y) noexcept
440 {
441 std::swap(x.ptr_, y.ptr_);
442 swap(x.map_, y.map_);
443 }
444
445 // as not everything is computed at compile time as for mdspan, update is usually faster than reconstructing
446 // everytime.
447 void update_extent(rank_type r, index_type new_value) { map_.update_extent(r, new_value); }
448
449 // for update_extent to make sense, as resizing the vector can move it in the memory
450 void update_data(data_handle_type ptr)
451 {
452 GUDHI_CHECK(ptr != nullptr, "Null pointer not valid input.");
453 ptr_ = ptr;
454 }
455
456 private:
457 data_handle_type ptr_;
458 mapping_type map_;
459};
460
461template <class CArray>
462Simple_mdspan(CArray&)
463 -> Simple_mdspan<std::remove_all_extents_t<CArray>, Gudhi::extents<std::size_t, std::extent_v<CArray, 0>>>;
464
465template <class Pointer>
466Simple_mdspan(Pointer&&)
467 -> Simple_mdspan<std::remove_pointer_t<std::remove_reference_t<Pointer>>, Gudhi::extents<std::size_t>>;
468
469template <class ElementType, class... Integrals>
470explicit Simple_mdspan(ElementType*, Integrals...)
471 -> Simple_mdspan<ElementType, Gudhi::dextents<std::size_t, sizeof...(Integrals)>>;
472
473template <class ElementType, class OtherIndexType, std::size_t N>
474Simple_mdspan(ElementType*, const std::array<OtherIndexType, N>&)
475 -> Simple_mdspan<ElementType, Gudhi::dextents<std::size_t, N>>;
476
477template <class ElementType, class IndexType, std::size_t... ExtentsPack>
478Simple_mdspan(ElementType*, const Gudhi::extents<IndexType, ExtentsPack...>&)
479 -> Simple_mdspan<ElementType, Gudhi::extents<IndexType, ExtentsPack...>>;
480
481template <class ElementType, class MappingType>
482Simple_mdspan(ElementType*, const MappingType&)
483 -> Simple_mdspan<ElementType, typename MappingType::extents_type, typename MappingType::layout_type>;
484
485} // namespace Gudhi
486
487#endif // GUDHI_SIMPLE_MDSPAN_H_
Reproduces the behaviour of C++23 std::extents class.
Definition simple_mdspan.h:106
Gudhi namespace.
Definition SimplicialComplexForAlpha.h:14