LvArray
indexing.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, Lawrence Livermore National Security, LLC and LvArray contributors.
3  * All rights reserved.
4  * See the LICENSE file for details.
5  * SPDX-License-Identifier: (BSD-3-Clause)
6  */
7 
13 #pragma once
14 
15 // Source includes
16 #include "Macros.hpp"
17 #include "typeManipulation.hpp"
18 #include "limits.hpp"
19 
20 // TPL includes
21 #include <RAJA/RAJA.hpp>
22 
23 namespace LvArray
24 {
25 
29 namespace indexing
30 {
31 
36 template< bool B_IS_ONE >
38 {
46  template< typename A, typename B >
47  static inline LVARRAY_HOST_DEVICE constexpr auto multiply( A const a, B const b )
48  { return a * b; }
49 };
50 
54 template<>
55 struct ConditionalMultiply< true >
56 {
65  template< typename A, typename B >
66  static inline LVARRAY_HOST_DEVICE constexpr A multiply( A const a, B const & )
67  { return a; }
68 };
69 
76 template< int SIZE, typename T >
77 LVARRAY_HOST_DEVICE inline constexpr
78 std::enable_if_t< (SIZE == 1), T >
79 multiplyAll( T const * const LVARRAY_RESTRICT values )
80 { return values[ 0 ]; }
81 
88 template< int SIZE, typename T >
89 LVARRAY_HOST_DEVICE inline constexpr
90 std::enable_if_t< (SIZE > 1), T >
91 multiplyAll( T const * const LVARRAY_RESTRICT values )
92 { return values[ 0 ] * multiplyAll< SIZE - 1 >( values + 1 ); }
93 
104 template< int USD, typename INDEX_TYPE, typename INDEX >
105 LVARRAY_HOST_DEVICE inline constexpr
106 INDEX_TYPE getLinearIndex( INDEX_TYPE const * const LVARRAY_RESTRICT strides, INDEX const index )
107 { return ConditionalMultiply< USD == 0 >::multiply( index, strides[ 0 ] ); }
108 
121 template< int USD, typename INDEX_TYPE, typename INDEX, typename ... REMAINING_INDICES >
122 LVARRAY_HOST_DEVICE inline constexpr
123 INDEX_TYPE getLinearIndex( INDEX_TYPE const * const LVARRAY_RESTRICT strides, INDEX const index, REMAINING_INDICES const ... indices )
124 {
125  return ConditionalMultiply< USD == 0 >::multiply( index, strides[ 0 ] ) +
126  getLinearIndex< USD - 1, INDEX_TYPE, REMAINING_INDICES... >( strides + 1, indices ... );
127 }
128 
/**
 * @brief String representation of an empty list of indices.
 * @return The string "{}".
 */
inline
std::string getIndexString()
{
  return std::string( "{}" );
}
133 
/**
 * @brief Format one or more indices as a brace-enclosed, comma-separated list.
 * @tparam INDEX The type of the first index.
 * @tparam REMAINING_INDICES The types of the other indices.
 * @param index The first index.
 * @param indices The other indices.
 * @return A string of the form "{ i0, i1, ... }".
 */
template< typename INDEX, typename ... REMAINING_INDICES >
std::string getIndexString( INDEX const index, REMAINING_INDICES const ... indices )
{
  std::ostringstream stream;
  stream << "{ " << index;

  // Pre-C++17 pack expansion: initializing this dummy array streams each remaining index,
  // left to right, since braced initializers are evaluated in order. The cast to void
  // guards against an overloaded comma operator.
  int const sequencer[] = { 0, ( static_cast< void >( stream << ", " << indices ), 0 )... };
  static_cast< void >( sequencer );

  stream << " }";
  return stream.str();
}
153 
161 template< typename INDEX_TYPE, typename ... INDICES >
162 std::string printDimsAndIndices( INDEX_TYPE const * const LVARRAY_RESTRICT dims, INDICES const... indices )
163 {
164  constexpr int NDIM = sizeof ... (INDICES);
165  std::ostringstream oss;
166  oss << "dimensions = { " << dims[ 0 ];
167  for( int i = 1; i < NDIM; ++i )
168  {
169  oss << ", " << dims[ i ];
170  }
171 
172  oss << " } indices = " << getIndexString( indices ... );
173 
174  return oss.str();
175 }
176 
184 template< typename INDEX_TYPE, typename ... INDICES >
185 LVARRAY_HOST_DEVICE inline constexpr
186 bool invalidIndices( INDEX_TYPE const * const LVARRAY_RESTRICT dims, INDICES const ... indices )
187 {
188  int curDim = 0;
189  bool invalid = false;
190  typeManipulation::forEachArg( [dims, &curDim, &invalid]( auto const index )
191  {
192  invalid = invalid || ( index < 0 ) || ( index >= dims[ curDim ] );
193  ++curDim;
194  }, indices ... );
195 
196  return invalid;
197 }
198 
206 template< typename INDEX_TYPE, typename ... INDICES >
207 LVARRAY_HOST_DEVICE inline
208 void checkIndices( INDEX_TYPE const * const LVARRAY_RESTRICT dims, INDICES const ... indices )
209 { LVARRAY_ERROR_IF( invalidIndices( dims, indices ... ), "Invalid indices. " << printDimsAndIndices( dims, indices ... ) ); }
210 
220 template< typename PERMUTATION, typename INDEX_TYPE, camp::idx_t NDIM >
221 LVARRAY_HOST_DEVICE inline
223 {
225  INDEX_TYPE foldedStrides[ NDIM ];
226 
227  for( int i = 0; i < NDIM; ++i )
228  {
229  foldedStrides[ i ] = 1;
230  for( int j = i + 1; j < NDIM; ++j )
231  {
232  foldedStrides[ i ] *= dims[ perm[ j ] ];
233  }
234  }
235 
237  for( int i = 0; i < NDIM; ++i )
238  {
239  strides[ perm[ i ] ] = foldedStrides[ i ];
240  }
241 
242  return strides;
243 }
244 
245 } // namespace indexing
246 } // namespace LvArray
LVARRAY_HOST_DEVICE constexpr bool invalidIndices(INDEX_TYPE const *const LVARRAY_RESTRICT dims, INDICES const ... indices)
Definition: indexing.hpp:186
LVARRAY_HOST_DEVICE constexpr std::enable_if_t<(SIZE==1), T > multiplyAll(T const *const LVARRAY_RESTRICT values)
Definition: indexing.hpp:79
static LVARRAY_HOST_DEVICE constexpr A multiply(A const a, B const &)
Definition: indexing.hpp:66
std::string getIndexString()
Definition: indexing.hpp:131
LVARRAY_HOST_DEVICE constexpr CArray< camp::idx_t, sizeof...(INDICES) > asArray(camp::idx_seq< INDICES... >)
Definition: typeManipulation.hpp:549
#define LVARRAY_ERROR_IF(EXP, MSG)
Abort execution if EXP is true.
Definition: Macros.hpp:122
LVARRAY_HOST_DEVICE typeManipulation::CArray< INDEX_TYPE, NDIM > calculateStrides(typeManipulation::CArray< INDEX_TYPE, NDIM > const &dims)
Calculate the strides given the dimensions and permutation.
Definition: indexing.hpp:222
Contains templates useful for type manipulation.
Contains portable access to std::numeric_limits and functions for converting between integral types...
DISABLE_HD_WARNING constexpr LVARRAY_HOST_DEVICE void forEachArg(F &&f)
The recursive base case where no argument is provided.
Definition: typeManipulation.hpp:129
static LVARRAY_HOST_DEVICE constexpr auto multiply(A const a, B const b)
Definition: indexing.hpp:47
LVARRAY_HOST_DEVICE void checkIndices(INDEX_TYPE const *const LVARRAY_RESTRICT dims, INDICES const ... indices)
Check that the indices are within dims; if not, the program is aborted.
Definition: indexing.hpp:208
A helper struct to multiply two numbers.
Definition: indexing.hpp:37
The top level namespace.
Definition: Array.hpp:24
std::string printDimsAndIndices(INDEX_TYPE const *const LVARRAY_RESTRICT dims, INDICES const ... indices)
Definition: indexing.hpp:162
Contains a bunch of macro definitions.
LVARRAY_HOST_DEVICE constexpr INDEX_TYPE getLinearIndex(INDEX_TYPE const *const LVARRAY_RESTRICT strides, INDEX const index)
Get the index into a one-dimensional space.
Definition: indexing.hpp:106
#define LVARRAY_HOST_DEVICE
Mark a function for both host and device usage.
Definition: Macros.hpp:549