16 #include "LvArrayConfig.hpp" 21 #include <type_traits> 23 #if defined( LVARRAY_USE_CUDA ) 24 #include <cuda_fp16.h> 47 template<
typename T,
typename U >
57 template<
typename T >
66 template<
typename T >
78 template<
typename T >
88 template<
typename T >
99 template<
typename T >
102 {
return __hlt( x, y ); }
104 #if defined( LVARRAY_USE_CUDA ) 112 template<
typename U >
114 __half
convert( __half
const, U
const u )
115 {
return __float2half_rn( u ); }
124 template<
typename U >
126 __half2
convert( __half2
const, U
const u )
127 {
return __float2half2_rn( u ); }
135 __half2
convert( __half2
const, __half
const u )
137 #if defined( LVARRAY_DEVICE_COMPILE ) 138 return __half2half2( u );
140 return __float2half2_rn( u );
153 template<
typename U,
typename V >
155 __half2
convert( __half2
const, U
const u, V
const v )
156 {
return __floats2half2_rn( u, v ); }
165 __half2
convert( __half2
const, __half
const u, __half
const v )
167 #if defined( LVARRAY_DEVICE_COMPILE ) 168 return __halves2half2( u, v );
170 return __floats2half2_rn( u, v );
198 {
return __low2half( x ); }
206 {
return __high2half( x ); }
210 #if defined( LVARRAY_USE_DEVICE ) 217 __half
lessThan( __half
const x, __half
const y )
218 {
return __hlt( x, y ); }
226 __half2
lessThan( __half2
const x, __half2
const y )
227 {
return __hlt2( x, y ); }
242 template<
typename T >
251 template<
typename T >
262 template<
typename T,
typename U >
277 template<
typename T,
typename U,
typename V >
288 template<
typename T >
299 template<
typename T >
310 template<
typename T >
312 std::enable_if_t< std::is_arithmetic< T >::value, T >
313 max( T
const a, T
const b )
315 #if defined(LVARRAY_DEVICE_COMPILE) 322 #if defined( LVARRAY_USE_DEVICE ) 326 __half
max( __half
const a, __half
const b )
328 #if defined(LVARRAY_USE_CUDA) && CUDART_VERSION > 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) 329 return __hmax( a, b );
330 #elif defined(LVARRAY_USE_HIP) 331 return __hgt( a, b ) ? a : b;
333 return a > b ? a : b;
339 __half2
max( __half2
const a, __half2
const b )
341 #if defined(LVARRAY_USE_CUDA) && CUDART_VERSION > 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) 342 return __hmax2( a, b );
344 __half2
const aFactor = __hge2( a, b );
345 __half2
const bFactor = convert< __half2 >( 1 ) - aFactor;
346 return a * aFactor + bFactor * b;
359 template<
typename T >
361 std::enable_if_t< std::is_arithmetic< T >::value, T >
362 min( T
const a, T
const b )
364 #if defined(LVARRAY_DEVICE_COMPILE) 371 #if defined( LVARRAY_USE_CUDA ) 376 __half
min( __half
const a, __half
const b )
378 #if CUDART_VERSION > 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) 379 return __hmin( a, b );
381 return a < b ? a : b;
388 __half2
min( __half2
const a, __half2
const b )
390 #if CUDART_VERSION > 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) 391 return __hmin2( a, b );
393 __half2
const aFactor = __hle2( a, b );
394 __half2
const bFactor = convert< __half2 >( 1 ) - aFactor;
395 return a * aFactor + bFactor * b;
406 template<
typename T >
410 #if defined(LVARRAY_DEVICE_COMPILE) 417 #if defined( LVARRAY_USE_DEVICE ) 421 __half
abs( __half
const x )
423 #if CUDART_VERSION > 11000 426 return x > __half( 0 ) ? x : -x;
432 __half2
abs( __half2
const x )
434 #if CUDART_VERSION > 11000 437 return x - __hle2( x, convert< __half2 >( 0 ) ) * ( x + x );
448 template<
typename T >
469 #if defined(LVARRAY_DEVICE_COMPILE) 477 template<
typename T >
481 #if defined(LVARRAY_DEVICE_COMPILE) 488 #if defined( LVARRAY_USE_DEVICE ) 492 __half
sqrt( __half
const x )
493 { return ::hsqrt( x ); }
497 __half2
sqrt( __half2
const x )
498 { return ::h2sqrt( x ); }
511 #if defined(LVARRAY_DEVICE_COMPILE) 512 return ::rsqrtf( x );
519 template<
typename T >
523 #if defined( LVARRAY_DEVICE_COMPILE ) 524 return ::rsqrt(
double( x ) );
530 #if defined( LVARRAY_USE_DEVICE ) 534 __half
invSqrt( __half
const x )
535 { return ::hrsqrt( x ); }
539 __half2
invSqrt( __half2
const x )
540 { return ::h2rsqrt( x ); }
558 float sin(
float const theta )
560 #if defined(LVARRAY_DEVICE_COMPILE) 561 return ::sinf( theta );
568 template<
typename T >
570 double sin( T
const theta )
572 #if defined(LVARRAY_DEVICE_COMPILE) 579 #if defined( LVARRAY_USE_DEVICE ) 583 __half
sin( __half
const theta )
584 { return ::hsin( theta ); }
588 __half2
sin( __half2
const theta )
589 { return ::h2sin( theta ); }
600 float cos(
float const theta )
602 #if defined(LVARRAY_DEVICE_COMPILE) 603 return ::cosf( theta );
610 template<
typename T >
612 double cos( T
const theta )
614 #if defined(LVARRAY_DEVICE_COMPILE) 621 #if defined( LVARRAY_USE_DEVICE ) 625 __half
cos( __half
const theta )
626 { return ::hcos( theta ); }
630 __half2
cos( __half2
const theta )
631 { return ::h2cos( theta ); }
642 void sincos(
float const theta,
float & sinTheta,
float & cosTheta )
644 #if defined(LVARRAY_DEVICE_COMPILE) 645 #if defined(LVARRAY_USE_CUDA) 646 ::sincos( theta, &sinTheta, &cosTheta );
647 #elif defined(LVARRAY_USE_HIP) 648 ::sincosf( theta, &sinTheta, &cosTheta );
657 template<
typename T >
659 void sincos(
double const theta,
double & sinTheta,
double & cosTheta )
661 #if defined(LVARRAY_DEVICE_COMPILE) 662 ::sincos( theta, &sinTheta, &cosTheta );
670 template<
typename T >
672 void sincos( T
const theta,
double & sinTheta,
double & cosTheta )
674 #if defined(LVARRAY_DEVICE_COMPILE) 685 #if defined( LVARRAY_USE_DEVICE ) 689 void sincos( __half
const theta, __half & sinTheta, __half & cosTheta )
691 sinTheta = ::hsin( theta );
692 cosTheta = ::hcos( theta );
697 void sincos( __half2
const theta, __half2 & sinTheta, __half2 & cosTheta )
699 sinTheta = ::h2sin( theta );
700 cosTheta = ::h2cos( theta );
712 float tan(
float const theta )
714 #if defined(LVARRAY_DEVICE_COMPILE) 715 return ::tanf( theta );
722 template<
typename T >
724 double tan( T
const theta )
726 #if defined(LVARRAY_DEVICE_COMPILE) 733 #if defined( LVARRAY_USE_DEVICE ) 737 __half
tan( __half
const theta )
746 __half2
tan( __half2
const theta )
776 template<
typename T >
780 T
const negate =
lessThan( x, math::convert< T >( 0 ) );
781 T
const absX =
abs( x );
783 T ret = math::convert< T >( -0.0187293 ) * absX + math::convert< T >( 0.0742610 );
784 ret = ret * absX - math::convert< T >( 0.2121144 );
785 ret = ret * absX + math::convert< T >( 1.5707288 );
786 ret = math::convert< T >( 3.14159265358979 * 0.5 ) - ret *
sqrt( math::convert< T >( 1 ) - absX );
787 ret = ret - negate * ( ret + ret );
788 T
const smallAngle =
lessThan( absX, math::convert< T >( 1.7e-1 ) );
789 return smallAngle * x + ( math::convert< T >( 1 ) - smallAngle ) * ret;
798 template<
typename T >
802 T
const negate =
lessThan( x, math::convert< T >( 0 ) );
803 T
const absX =
abs( x );
805 T ret = math::convert< T >( -0.0187293 ) * absX + math::convert< T >( 0.0742610 );
806 ret = ret * absX - math::convert< T >( 0.2121144 );
807 ret = ret * absX + math::convert< T >( 1.5707288 );
808 ret = ret *
sqrt( math::convert< T >( 1 ) - absX );
809 ret = ret - negate * ( ret + ret );
810 return negate * math::convert< T >( 3.14159265358979 ) + ret;
820 template<
typename T >
825 T
const absX =
abs( x );
826 T
const absY =
abs( y );
827 T
const ratio =
min( absX, absY ) /
max( absX, absY );
828 T
const ratio2 = ratio * ratio;
830 T ret = math::convert< T >( -0.013480470 ) * ratio2 + math::convert< T >( 0.057477314 );
831 ret = ret * ratio2 - math::convert< T >( 0.121239071 );
832 ret = ret * ratio2 + math::convert< T >( 0.195635925 );
833 ret = ret * ratio2 - math::convert< T >( 0.332994597 );
834 ret = ret * ratio2 + math::convert< T >( 0.999995630 );
841 ret = internal::lessThan( absX, absY ) * ( math::convert< T >( 1.570796327 ) - ret - ret ) + ret;
842 ret = internal::lessThan( x, math::convert< T >( 0 ) ) * ( math::convert< T >( 3.141592654 ) - ret - ret ) + ret;
843 ret = ret - internal::lessThan( y, math::convert< T >( 0 ) ) * ( ret + ret );
859 #if defined(LVARRAY_DEVICE_COMPILE) 867 template<
typename T >
871 #if defined(LVARRAY_DEVICE_COMPILE) 878 #if defined( LVARRAY_USE_DEVICE ) 882 __half
asin( __half
const x )
883 {
return internal::asinImpl( x ); }
887 __half2
asin( __half2
const x )
888 {
return internal::asinImpl( x ); }
901 #if defined(LVARRAY_DEVICE_COMPILE) 909 template<
typename T >
913 #if defined(LVARRAY_DEVICE_COMPILE) 920 #if defined( LVARRAY_USE_DEVICE ) 924 __half
acos( __half
const x )
925 {
return internal::acosImpl( x ); }
929 __half2
acos( __half2
const x )
930 {
return internal::acosImpl( x ); }
942 float atan2(
float const y,
float const x )
944 #if defined(LVARRAY_DEVICE_COMPILE) 945 return ::atan2f( y, x );
952 template<
typename T >
954 double atan2( T
const y, T
const x )
956 #if defined(LVARRAY_DEVICE_COMPILE) 963 #if defined( LVARRAY_USE_CUDA ) 967 __half
atan2( __half
const y, __half
const x )
968 {
return internal::atan2Impl( y, x ); }
972 __half2
atan2( __half2
const y, __half2
const x )
973 {
return internal::atan2Impl( y, x ); }
991 float exp(
float const x )
993 #if defined(LVARRAY_DEVICE_COMPILE) 1001 template<
typename T >
1005 #if defined(LVARRAY_DEVICE_COMPILE) 1012 #if defined( LVARRAY_USE_DEVICE ) 1016 __half
exp( __half
const x )
1017 { return ::hexp( x ); }
1021 __half2
exp( __half2
const x )
1022 { return ::h2exp( x ); }
1035 #if defined(LVARRAY_DEVICE_COMPILE) 1043 template<
typename T >
1047 #if defined(LVARRAY_DEVICE_COMPILE) 1054 #if defined( LVARRAY_USE_DEVICE ) 1058 __half
log( __half
const x )
1059 { return ::hlog( x ); }
1063 __half2
log( __half2
const x )
1064 { return ::h2log( x ); }
LVARRAY_DEVICE LVARRAY_FORCE_INLINE SingleType< T > getSecond(T const x)
Definition: math.hpp:301
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE void sincos(float const theta, float &sinTheta, float &cosTheta)
Compute the sine and cosine of theta.
Definition: math.hpp:642
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE double log(T const x)
Definition: math.hpp:1045
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE constexpr T square(T const x)
Definition: math.hpp:450
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float invSqrt(float const x)
Definition: math.hpp:509
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE double atan2(T const y, T const x)
Definition: math.hpp:954
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float sin(float const theta)
Definition: math.hpp:558
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float exp(float const x)
Definition: math.hpp:991
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE double exp(T const x)
Definition: math.hpp:1003
LVARRAY_DEVICE LVARRAY_FORCE_INLINE T atan2Impl(T const y, T const x)
Definition: math.hpp:823
LVARRAY_DEVICE LVARRAY_FORCE_INLINE T asinImpl(T const x)
Definition: math.hpp:778
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE constexpr T convert(U const u, V const v)
Convert u and v to a dual type.
Definition: math.hpp:279
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE constexpr std::enable_if_t< std::is_arithmetic< T >::value, T > min(T const a, T const b)
Definition: math.hpp:362
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE constexpr T convert(U const u)
Convert u to type.
Definition: math.hpp:264
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE double cos(T const theta)
Definition: math.hpp:612
#define LVARRAY_FORCE_INLINE
Marks a function/lambda for inlining.
Definition: Macros.hpp:44
The type of a single value of type T.
Definition: math.hpp:67
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE double sin(T const theta)
Definition: math.hpp:570
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE double acos(T const x)
Definition: math.hpp:911
T type
An alias for T.
Definition: math.hpp:70
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE constexpr T abs(T const x)
Definition: math.hpp:408
#define LVARRAY_DEVICE
Mark a function for only device usage.
Definition: Macros.hpp:605
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float cos(float const theta)
Definition: math.hpp:600
LVARRAY_HOST_DEVICE constexpr T lessThan(T const x, T const y)
Definition: math.hpp:101
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float tan(float const theta)
Definition: math.hpp:712
The top level namespace.
Definition: Array.hpp:24
LVARRAY_DEVICE LVARRAY_FORCE_INLINE SingleType< T > getFirst(T const x)
Definition: math.hpp:290
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE double sqrt(T const x)
Definition: math.hpp:479
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float asin(float const x)
Definition: math.hpp:857
Contains a bunch of macro definitions.
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE constexpr int numValues()
Return the number of values stored in type.
Definition: math.hpp:244
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float sqrt(float const x)
Definition: math.hpp:467
typename internal::SingleType< T >::type SingleType
The type of a single value of type T.
Definition: math.hpp:252
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE double tan(T const theta)
Definition: math.hpp:724
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float atan2(float const y, float const x)
Definition: math.hpp:942
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float acos(float const x)
Definition: math.hpp:899
LVARRAY_DEVICE LVARRAY_FORCE_INLINE T acosImpl(T const x)
Definition: math.hpp:800
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE constexpr std::enable_if_t< std::is_arithmetic< T >::value, T > max(T const a, T const b)
Definition: math.hpp:313
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE float log(float const x)
Definition: math.hpp:1033
#define LVARRAY_HOST_DEVICE
Mark a function for both host and device usage.
Definition: Macros.hpp:600
LVARRAY_HOST_DEVICE LVARRAY_FORCE_INLINE double asin(T const x)
Definition: math.hpp:869