Low-Level Abstraction of Memory Access
ArrayIndexRange.hpp
Go to the documentation of this file.
1 // Copyright 2022 Bernhard Manfred Gruber
2 // SPDX-License-Identifier: MPL-2.0
3 
4 #pragma once
5 
6 #include "ArrayExtents.hpp"
7 #include "Core.hpp"
8 #include "macros.hpp"
9 
10 #include <algorithm>
11 #include <iterator>
12 #include <limits>
13 #if CAN_USE_RANGES
14 # include <ranges>
15 #endif
16 
17 namespace llama
18 {
/// Random-access iterator over all indices of an ArrayExtents. Indices are visited in row-major
/// order: the last dimension varies fastest (see operator++).
// NOTE(review): this listing is missing the struct declarator line between the template header and
// the opening brace (presumably `struct ArrayIndexIterator`), as well as the `reference` and
// `pointer` member typedefs used by operator[] and operator-> — restore them from the real header.
21  template<typename ArrayExtents>
23  {
// The extents type must be given without a const qualifier.
24  static_assert(!std::is_const_v<ArrayExtents>);
25 
// Dereferencing yields a whole array index by value (see operator*).
26  using value_type = typename ArrayExtents::Index;
27  using difference_type = std::ptrdiff_t;
30  using iterator_category = std::random_access_iterator_tag;
31 
// Number of dimensions of the iterated extents.
32  static constexpr std::size_t rank = ArrayExtents::rank;
33 
34  constexpr ArrayIndexIterator() noexcept = default;
35 
// Stores the extents being iterated and the index the iterator starts at.
// NOTE(review): the constructor's signature line is missing from this listing; ArrayIndexRange
// calls it as {extents, index} from begin()/end(), so it is presumably (ArrayExtents, value_type).
37  : extents(extents)
38  , current(current)
39  {
40  }
41 
43  constexpr auto operator*() const noexcept -> value_type
44  {
45  return current;
46  }
47 
49  constexpr auto operator->() const noexcept -> pointer
50  {
51  return {**this};
52  }
53 
55  constexpr auto operator++() noexcept -> ArrayIndexIterator&
56  {
57  current[rank - 1]++;
58  for(auto i = static_cast<int>(rank) - 2; i >= 0; i--)
59  {
60  if(current[i + 1] != extents[i + 1])
61  return *this;
62  current[i + 1] = 0;
63  current[i]++;
64  }
65  return *this;
66  }
67 
69  constexpr auto operator++(int) noexcept -> ArrayIndexIterator
70  {
71  auto tmp = *this;
72  ++*this;
73  return tmp;
74  }
75 
77  constexpr auto operator--() noexcept -> ArrayIndexIterator&
78  {
79  current[rank - 1]--;
80  for(auto i = static_cast<int>(rank) - 2; i >= 0; i--)
81  {
82  // return if no underflow
83  if(current[i + 1] != static_cast<typename ArrayExtents::value_type>(-1))
84  return *this;
85  current[i + 1] = extents[i] - 1;
86  current[i]--;
87  }
88  // decrementing beyond [0, 0, ..., 0] is UB
89  return *this;
90  }
91 
93  constexpr auto operator--(int) noexcept -> ArrayIndexIterator
94  {
95  auto tmp = *this;
96  --*this;
97  return tmp;
98  }
99 
101  constexpr auto operator[](difference_type i) const noexcept -> reference
102  {
103  return *(*this + i);
104  }
105 
107  constexpr auto operator+=(difference_type n) noexcept -> ArrayIndexIterator&
108  {
109  // add n to all lower dimensions with carry
110  for(auto i = static_cast<int>(rank) - 1; i > 0 && n != 0; i--)
111  {
112  n += static_cast<difference_type>(current[i]);
113  const auto s = static_cast<difference_type>(extents[i]);
114  auto mod = n % s;
115  n /= s;
116  if(mod < 0)
117  {
118  mod += s;
119  n--;
120  }
121  current[i] = mod;
122  assert(current[i] < extents[i]);
123  }
124 
125  current[0] = static_cast<difference_type>(current[0]) + n;
126  // current is either within bounds or at the end ([last + 1, 0, 0, ..., 0])
127  assert(
128  (current[0] < extents[0]
129  || (current[0] == extents[0]
130  && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
131  && "Iterator was moved past the end");
132 
133  return *this;
134  }
135 
137  friend constexpr auto operator+(ArrayIndexIterator it, difference_type n) noexcept -> ArrayIndexIterator
138  {
139  it += n;
140  return it;
141  }
142 
144  friend constexpr auto operator+(difference_type n, ArrayIndexIterator it) noexcept -> ArrayIndexIterator
145  {
146  return it + n;
147  }
148 
150  constexpr auto operator-=(difference_type n) noexcept -> ArrayIndexIterator&
151  {
152  return operator+=(-n);
153  }
154 
156  friend constexpr auto operator-(ArrayIndexIterator it, difference_type n) noexcept -> ArrayIndexIterator
157  {
158  it -= n;
159  return it;
160  }
161 
163  friend constexpr auto operator-(const ArrayIndexIterator& a, const ArrayIndexIterator& b) noexcept
164  -> difference_type
165  {
166  assert(a.extents == b.extents);
167 
168  difference_type n = a.current[rank - 1] - b.current[rank - 1];
169  difference_type size = a.extents[rank - 1];
170  for(auto i = static_cast<int>(rank) - 2; i >= 0; i--)
171  {
172  n += (a.current[i] - b.current[i]) * size;
173  size *= a.extents[i];
174  }
175 
176  return n;
177  }
178 
180  friend constexpr auto operator==(
182  const ArrayIndexIterator<ArrayExtents>& b) noexcept -> bool
183  {
184  assert(a.extents == b.extents);
185  return a.current == b.current;
186  }
187 
189  friend constexpr auto operator!=(
191  const ArrayIndexIterator<ArrayExtents>& b) noexcept -> bool
192  {
193  return !(a == b);
194  }
195 
197  friend constexpr auto operator<(const ArrayIndexIterator& a, const ArrayIndexIterator& b) noexcept -> bool
198  {
199  assert(a.extents == b.extents);
200 #ifdef __NVCC__
201  // from: https://en.cppreference.com/w/cpp/algorithm/lexicographical_compare
202  auto first1 = std::begin(a.current);
203  auto last1 = std::end(a.current);
204  auto first2 = std::begin(b.current);
205  auto last2 = std::end(b.current);
206  for(; (first1 != last1) && (first2 != last2); ++first1, (void) ++first2)
207  {
208  if(*first1 < *first2)
209  return true;
210  if(*first2 < *first1)
211  return false;
212  }
213 
214  return (first1 == last1) && (first2 != last2);
215 #else
216  return std::lexicographical_compare(
217  std::begin(a.current),
218  std::end(a.current),
219  std::begin(b.current),
220  std::end(b.current));
221 #endif
222  }
223 
225  friend constexpr auto operator>(const ArrayIndexIterator& a, const ArrayIndexIterator& b) noexcept -> bool
226  {
227  return b < a;
228  }
229 
231  friend constexpr auto operator<=(const ArrayIndexIterator& a, const ArrayIndexIterator& b) noexcept -> bool
232  {
233  return !(a > b);
234  }
235 
237  friend constexpr auto operator>=(const ArrayIndexIterator& a, const ArrayIndexIterator& b) noexcept -> bool
238  {
239  return !(a < b);
240  }
241 
242  private:
// Extents being iterated; operator++/--/+= use them to decide where each dimension wraps around.
243  ArrayExtents extents; // TODO(bgruber): we only need to store rank - 1 sizes
// The array index this iterator currently points to (returned by operator*).
244  value_type current;
245  };
246 
/// Range over all indices of an ArrayExtents, usable in range-based for loops via begin()/end().
/// The extents are stored as a private base class.
// NOTE(review): the struct declarator line (presumably `struct ArrayIndexRange`) is missing from
// this listing between the template header and the base clause — restore it from the real header.
249  template<typename ArrayExtents>
251  : private ArrayExtents
// When C++20 ranges are available, the type additionally derives from std::ranges::view_base,
// which opts it into std::ranges::view.
252 #if CAN_USE_RANGES
253  , std::ranges::view_base
254 #endif
255  {
// The extents type must be given without a const qualifier.
256  static_assert(!std::is_const_v<ArrayExtents>);
257 
258  constexpr ArrayIndexRange() noexcept = default;
259 
261  constexpr explicit ArrayIndexRange(ArrayExtents extents) noexcept : ArrayExtents(extents)
262  {
263  }
264 
266  constexpr auto begin() const noexcept -> ArrayIndexIterator<ArrayExtents>
267  {
268  return {*this, typename ArrayExtents::Index{}};
269  }
270 
272  constexpr auto end() const noexcept -> ArrayIndexIterator<ArrayExtents>
273  {
274  auto endPos = typename ArrayExtents::Index{};
275  endPos[0] = this->toArray()[0];
276  return {*this, endPos};
277  }
278  };
279 } // namespace llama
#define LLAMA_EXPORT
Definition: macros.hpp:192
#define LLAMA_FN_HOST_ACC_INLINE
Definition: macros.hpp:96
constexpr auto toArray() const -> Index
static constexpr std::size_t rank
ArrayIndex< T, rank > Index
Iterator supporting ArrayIndexRange.
constexpr auto operator-=(difference_type n) noexcept -> ArrayIndexIterator &
constexpr friend auto operator+(ArrayIndexIterator it, difference_type n) noexcept -> ArrayIndexIterator
constexpr auto operator*() const noexcept -> value_type
constexpr auto operator--(int) noexcept -> ArrayIndexIterator
constexpr friend auto operator<(const ArrayIndexIterator &a, const ArrayIndexIterator &b) noexcept -> bool
constexpr friend auto operator>(const ArrayIndexIterator &a, const ArrayIndexIterator &b) noexcept -> bool
constexpr friend auto operator==(const ArrayIndexIterator< ArrayExtents > &a, const ArrayIndexIterator< ArrayExtents > &b) noexcept -> bool
static constexpr std::size_t rank
constexpr friend auto operator!=(const ArrayIndexIterator< ArrayExtents > &a, const ArrayIndexIterator< ArrayExtents > &b) noexcept -> bool
constexpr friend auto operator+(difference_type n, ArrayIndexIterator it) noexcept -> ArrayIndexIterator
constexpr auto operator++(int) noexcept -> ArrayIndexIterator
typename ArrayExtents::Index value_type
constexpr auto operator+=(difference_type n) noexcept -> ArrayIndexIterator &
constexpr auto operator[](difference_type i) const noexcept -> reference
constexpr friend auto operator-(ArrayIndexIterator it, difference_type n) noexcept -> ArrayIndexIterator
constexpr friend auto operator>=(const ArrayIndexIterator &a, const ArrayIndexIterator &b) noexcept -> bool
constexpr friend auto operator-(const ArrayIndexIterator &a, const ArrayIndexIterator &b) noexcept -> difference_type
constexpr auto operator++() noexcept -> ArrayIndexIterator &
std::random_access_iterator_tag iterator_category
constexpr auto operator--() noexcept -> ArrayIndexIterator &
constexpr ArrayIndexIterator() noexcept=default
constexpr friend auto operator<=(const ArrayIndexIterator &a, const ArrayIndexIterator &b) noexcept -> bool
constexpr auto operator->() const noexcept -> pointer
Range allowing to iterate over all indices in an ArrayExtents.
constexpr auto end() const noexcept -> ArrayIndexIterator< ArrayExtents >
constexpr auto begin() const noexcept -> ArrayIndexIterator< ArrayExtents >
constexpr ArrayIndexRange() noexcept=default