1616
1717// ================================================================================
1818// this file has been auto-generated, do not modify its contents!
19- // date: 2024-11-18 16:57:58.817191
20- // git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a
19+ // date: 2024-11-20 10:36:45.284577
20+ // git hash: 76501fda40df9e396998d11840bc8f10b11ea47b
2121// ================================================================================
2222
2323#ifndef KERNEL_FLOAT_MACROS_H
@@ -813,7 +813,7 @@ struct approx_level_policy {};
813813using approx_policy = approx_level_policy<>;
814814
// Fallback definition of KERNEL_FLOAT_POLICY: it only takes effect when the
// macro has not been defined earlier. The removed line ended in `;`, the added
// line does not -- presumably so that the macro expands cleanly wherever a bare
// type name is expected (NOTE(review): confirm against the macro's use sites).
815815#ifndef KERNEL_FLOAT_POLICY
816- #define KERNEL_FLOAT_POLICY accurate_policy;
816+ #define KERNEL_FLOAT_POLICY accurate_policy
817817#endif
818818
819819/* *
@@ -1448,6 +1448,9 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f")
// Invocations of KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX for float rsqrt and
// float tanh, each passing an instruction string and a short "f" string.
// NOTE(review): the tanh instruction string ends in `;` while the rsqrt one
// does not. The macro body is not visible here -- confirm that the extra `;`
// is intended and does not end up in the middle of the generated statement.
14481448KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float , rsqrt, " rsqrt.approx.f32" , " f" )
14491449KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float , tanh, " tanh.approx.f32;" , " f" )
14501450
// X-macro helper: expands to F(op) once for each of exp, exp2, exp10, log,
// log2, log10, sin, cos, tan, rcp, rsqrt and sqrt. Pass a one-argument macro
// as F to stamp out one definition per listed operation; this file does so
// with KERNEL_FLOAT_FAST_FP16_DISPATCH and KERNEL_FLOAT_FAST_BF16_DISPATCH.
1451+ #define KERNEL_FLOAT_FAST_F32_MAP (F ) \
1452+ F (exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt)
1453+
14511454// KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f")
14521455// KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f")
14531456// KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f")
@@ -1724,15 +1727,15 @@ using zip_common_type = vector<
17241727 * vec<float, 3> c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f]
17251728 * ```
17261729 */
1727- template <typename F, typename L, typename R>
1730+ template <typename Accuracy = default_policy, typename F, typename L, typename R>
17281731KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common (F fun, const L& left, const R& right) {
17291732 using T = promoted_vector_value_type<L, R>;
17301733 using O = result_t <F, T, T>;
17311734 using E = broadcast_vector_extent_type<L, R>;
17321735
17331736 vector_storage<O, extent_size<E>> result;
17341737
1735- detail::default_map_impl< F, extent_size<E>, O, T, T>::call (
1738+ detail::map_impl<Accuracy, F, extent_size<E>, O, T, T>::call (
17361739 fun,
17371740 result.data (),
17381741 detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call (
@@ -1745,10 +1748,17 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
17451748 return result;
17461749}
17471750
// KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME): defines a function template
// NAME(left, right) that forwards both operands to zip_common together with a
// default-constructed ops::NAME<C> functor, where C defaults to
// promoted_vector_value_type<L, R>.
//
// The removed (`-`) version called zip_common without an accuracy argument.
// The added (`+`) version introduces a leading `Accuracy` template parameter
// (default: default_policy) and passes it on as zip_common<Accuracy>(...).
// Because L and R are still deduced from the call arguments, existing calls
// of the form NAME(a, b) keep compiling unchanged.
1748- #define KERNEL_FLOAT_DEFINE_BINARY_FUN (NAME ) \
1749- template <typename L, typename R, typename C = promoted_vector_value_type<L, R>> \
1750- KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME (L&& left, R&& right) { \
1751- return zip_common (ops::NAME<C> {}, static_cast <L&&>(left), static_cast <R&&>(right)); \
1751+ #define KERNEL_FLOAT_DEFINE_BINARY_FUN (NAME ) \
1752+ template < \
1753+ typename Accuracy = default_policy, \
1754+ typename L, \
1755+ typename R, \
1756+ typename C = promoted_vector_value_type<L, R>> \
1757+ KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME (L&& left, R&& right) { \
1758+ return zip_common<Accuracy>( \
1759+ ops::NAME<C> {}, \
1760+ static_cast <L&&>(left), \
1761+ static_cast <R&&>(right)); \
17521762 }
17531763
17541764#define KERNEL_FLOAT_DEFINE_BINARY (NAME, EXPR, EXPR_F64, EXPR_F32 ) \
@@ -3887,11 +3897,20 @@ struct vector: public S {
38873897 }
38883898
38893899 /* *
3890- * Returns the result of `* this + lhs * rhs`.
3900+ * Returns the result of `this + lhs * rhs`.
38913901 *
38923902 * The operation is performed using a single `kernel_float::fma` call, which may be faster than performing
38933903 * the addition and multiplication separately.
38943904 */
// add_mul(lhs, rhs): member function added by this change. It is a single
// call to ::kernel_float::fma with the arguments (lhs, rhs, *this).
//   T2 - promote_t of T and the value types of L and R (element type of the result).
//   E2 - broadcast_extent of E and the extents of L and R (extent of the result).
3905+ template <
3906+ typename L,
3907+ typename R,
3908+ typename T2 = promote_t <T, vector_value_type<L>, vector_value_type<R>>,
3909+ typename E2 = broadcast_extent<E, vector_extent_type<L>, vector_extent_type<R>>>
3910+ KERNEL_FLOAT_INLINE vector<T2, E2 > add_mul (const L& lhs, const R& rhs) const {
3911+ return ::kernel_float::fma (lhs, rhs, *this );
3912+ }
3913+
38953914 template <
38963915 typename L,
38973916 typename R,
@@ -4138,6 +4157,22 @@ struct apply_impl<accurate_policy, ops::fma<half_t>, 2, half_t, half_t, half_t,
41384157 result[0 ] = r.x , result[1 ] = r.y ;
41394158 }
41404159};
4160+
4161+ // clang-format off
// KERNEL_FLOAT_FAST_FP16_DISPATCH(OP): for any N, specializes
// apply_impl<fast_policy, ops::OP<half_t>, N, half_t, half_t>.
// Its call() routes the operation through float in three steps:
//   1. cast the N half_t inputs into the local float array `v`,
//   2. apply ops::OP<float> under fast_policy, in place on `v`
//      (`v` is passed as both output and input),
//   3. cast `v` back to half_t into `output`.
4162+ #define KERNEL_FLOAT_FAST_FP16_DISPATCH (OP ) \
4163+ template <size_t N> \
4164+ struct apply_impl <fast_policy, ops::OP<half_t >, N, half_t , half_t > { \
4165+ KERNEL_FLOAT_INLINE static void \
4166+ call (ops::OP<half_t >, half_t * output, const half_t * input) { \
4167+ float v[N]; \
4168+ map_impl<fast_policy, ops::cast<half_t , float >, N, float , half_t >::call ({}, v, input); \
4169+ map_impl<fast_policy, ops::OP<float >, N, float , float >::call ({}, v, v); \
4170+ map_impl<fast_policy, ops::cast<float , half_t >, N, half_t , float >::call ({}, output, v); \
4171+ } \
4172+ };
4173+ // clang-format on
4174+
4175+ KERNEL_FLOAT_FAST_F32_MAP (KERNEL_FLOAT_FAST_FP16_DISPATCH)
41414176} // namespace detail
41424177#endif
41434178
@@ -4390,6 +4425,22 @@ struct apply_impl<
43904425 result[0 ] = r.x , result[1 ] = r.y ;
43914426 }
43924427};
4428+
4429+ // clang-format off
// KERNEL_FLOAT_FAST_BF16_DISPATCH(OP): for any N, specializes
// apply_impl<fast_policy, ops::OP<bfloat16_t>, N, bfloat16_t, bfloat16_t>.
// Its call() routes the operation through float in three steps:
//   1. cast the N bfloat16_t inputs into the local float array `v`,
//   2. apply ops::OP<float> under fast_policy, in place on `v`
//      (`v` is passed as both output and input),
//   3. cast `v` back to bfloat16_t into `output`.
4430+ #define KERNEL_FLOAT_FAST_BF16_DISPATCH (OP ) \
4431+ template <size_t N> \
4432+ struct apply_impl <fast_policy, ops::OP<bfloat16_t >, N, bfloat16_t , bfloat16_t > { \
4433+ KERNEL_FLOAT_INLINE static void \
4434+ call (ops::OP<bfloat16_t >, bfloat16_t * output, const bfloat16_t * input) { \
4435+ float v[N]; \
4436+ map_impl<fast_policy, ops::cast<bfloat16_t , float >, N, float , bfloat16_t >::call ({}, v, input); \
4437+ map_impl<fast_policy, ops::OP<float >, N, float , float >::call ({}, v, v); \
4438+ map_impl<fast_policy, ops::cast<float , bfloat16_t >, N, bfloat16_t , float >::call ({}, output, v); \
4439+ } \
4440+ };
4441+ // clang-format on
4442+
4443+ KERNEL_FLOAT_FAST_F32_MAP (KERNEL_FLOAT_FAST_BF16_DISPATCH)
43934444} // namespace detail
43944445#endif
43954446
@@ -4631,17 +4682,20 @@ KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) {
46314682
46324683template <int Iter>
46334684KERNEL_FLOAT_DEVICE half2_t rsqrt (half2_t x) {
4685+ // A small number added such that rsqrt(0) does not return NaN
4686+ static constexpr double EPS = 0.00000768899917602539 ;
4687+
46344688 // Set top and bottom bits for both halves, then shift by 1, then invert
46354689 uint32_t r = ~((uint32_t (transmute<uint32_t >(x) >> 1 )) | ~uint32_t (0x3fff3fff ));
4636- // uint32_t r = uint32_t(~(transmute<uint32_t>(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1;
46374690
4638- // Add bias (0x199c)
4639- half2_t y = transmute<half2_t >(uint32_t (r) + uint32_t (0x199c199c ));
4691+ // Add bias
4692+ static constexpr uint32_t BIAS = 0x199c199c ;
4693+ half2_t y = transmute<half2_t >(uint32_t (r) + BIAS);
46404694
46414695 // Newton-Raphson iterations
46424696#pragma unroll
46434697 for (int i = 0 ; i < Iter; i++) {
4644- half2_t half_x = make_half2 (-0.5 ) * x ;
4698+ half2_t half_x = __hfma2 ( make_half2 (-0.5 ), x, make_half2 (-EPS)) ;
46454699 half2_t correction = __hfma2 (half_x, y * y, make_half2 (0.5 ));
46464700 y = __hfma2 (correction, y, y); // y += y * correction
46474701 }
@@ -4836,7 +4890,7 @@ template<int Level, typename F, typename T>
48364890struct apply_impl <approx_level_policy<Level>, F, 1 , T, T> {
48374891 KERNEL_FLOAT_INLINE static void call (F fun, T* output, const T* input) {
48384892 T in2[2 ], out2[2 ];
4839- out2 [0 ] = input[0 ];
4893+ in2 [0 ] = input[0 ];
48404894 apply_impl<approx_level_policy<Level>, F, 2 , T, T>::call (fun, out2, in2);
48414895 output[0 ] = out2[0 ];
48424896 }
@@ -4867,6 +4921,8 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
48674921KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , rcp, 1 )
48684922KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , exp, 0 )
48694923KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , log, 0 )
4924+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , asin, 2 )
4925+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , acos, 2 )
48704926#endif
48714927
48724928#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
@@ -4960,7 +5016,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
49605016#define KERNEL_FLOAT_FP8_CAST2 (T, FP8_TY, FP8_INTERP ) \
49615017 namespace detail { \
49625018 template <> \
4963- struct apply_impl <ops::cast<T, FP8_TY>, 2 , FP8_TY, T> { \
5019+ struct apply_impl <accurate_policy, ops::cast<T, FP8_TY>, 2 , FP8_TY, T> { \
49645020 KERNEL_FLOAT_INLINE static void call (ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
49655021 __half2_raw x; \
49665022 memcpy (&x, v, 2 * sizeof (T)); \
@@ -4969,7 +5025,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
49695025 } \
49705026 }; \
49715027 template <> \
4972- struct apply_impl <ops::cast<FP8_TY, T>, 2 , T, FP8_TY> { \
5028+ struct apply_impl <accurate_policy, ops::cast<FP8_TY, T>, 2 , T, FP8_TY> { \
49735029 KERNEL_FLOAT_INLINE static void call (ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
49745030 __nv_fp8x2_storage_t x; \
49755031 memcpy (&x, v, 2 * sizeof (FP8_TY)); \
@@ -4987,12 +5043,12 @@ KERNEL_FLOAT_FP8_CAST(double)
49875043
49885044
49895045namespace kernel_float {
// FP16 <-> FP8 registrations for the two FP8 types __nv_fp8_e4m3 and
// __nv_fp8_e5m2. The removed (`-`) lines name the 16-bit type `__half`; the
// added (`+`) lines name it `half_t` instead.
// NOTE(review): KERNEL_FLOAT_DEFINE_PROMOTED_TYPE and KERNEL_FLOAT_FP8_CAST are
// defined outside this view; presumably they declare the promotion rule and
// the conversions between the 16-bit type and each FP8 type -- confirm.
4990- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half , __nv_fp8_e4m3)
4991- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half , __nv_fp8_e5m2)
5046+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (half_t , __nv_fp8_e4m3)
5047+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (half_t , __nv_fp8_e5m2)
49925048
4993- KERNEL_FLOAT_FP8_CAST (__half )
4994- KERNEL_FLOAT_FP8_CAST2 (__half , __nv_fp8_e4m3, __NV_E4M3)
4995- KERNEL_FLOAT_FP8_CAST2 (__half , __nv_fp8_e5m2, __NV_E5M2)
5049+ KERNEL_FLOAT_FP8_CAST (half_t )
5050+ KERNEL_FLOAT_FP8_CAST2 (half_t , __nv_fp8_e4m3, __NV_E4M3)
5051+ KERNEL_FLOAT_FP8_CAST2 (half_t , __nv_fp8_e5m2, __NV_E5M2)
49965052
49975053} // namespace kernel_float
49985054#endif // KERNEL_FLOAT_FP16_AVAILABLE
@@ -5001,12 +5057,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
50015057
50025058
50035059namespace kernel_float {
// BF16 <-> FP8 registrations for the two FP8 types __nv_fp8_e4m3 and
// __nv_fp8_e5m2. The removed (`-`) lines name the 16-bit type `__nv_bfloat16`;
// the added (`+`) lines name it `bfloat16_t` instead.
// NOTE(review): KERNEL_FLOAT_DEFINE_PROMOTED_TYPE and KERNEL_FLOAT_FP8_CAST are
// defined outside this view; presumably they declare the promotion rule and
// the conversions between the 16-bit type and each FP8 type -- confirm.
5004- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16 , __nv_fp8_e4m3)
5005- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16 , __nv_fp8_e5m2)
5060+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (bfloat16_t , __nv_fp8_e4m3)
5061+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (bfloat16_t , __nv_fp8_e5m2)
50065062
5007- KERNEL_FLOAT_FP8_CAST (__nv_bfloat16 )
5008- KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16 , __nv_fp8_e4m3, __NV_E4M3)
5009- KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16 , __nv_fp8_e5m2, __NV_E5M2)
5063+ KERNEL_FLOAT_FP8_CAST (bfloat16_t )
5064+ KERNEL_FLOAT_FP8_CAST2 (bfloat16_t , __nv_fp8_e4m3, __NV_E4M3)
5065+ KERNEL_FLOAT_FP8_CAST2 (bfloat16_t , __nv_fp8_e5m2, __NV_E5M2)
50105066} // namespace kernel_float
50115067#endif // KERNEL_FLOAT_BF16_AVAILABLE
50125068
@@ -5075,14 +5131,14 @@ KERNEL_FLOAT_TYPE_ALIAS(f64x, double)
50755131KERNEL_FLOAT_TYPE_ALIAS (float64x, double )
50765132
// Alias registrations for the names half, f16x and float16x, compiled only
// when KERNEL_FLOAT_FP16_AVAILABLE is set. The removed (`-`) lines pass
// `__half` as the underlying type; the added (`+`) lines pass `half_t`.
50775133#if KERNEL_FLOAT_FP16_AVAILABLE
5078- KERNEL_FLOAT_TYPE_ALIAS (half, __half )
5079- KERNEL_FLOAT_TYPE_ALIAS(f16x, __half )
5080- KERNEL_FLOAT_TYPE_ALIAS(float16x, __half )
5134+ KERNEL_FLOAT_TYPE_ALIAS (half, half_t )
5135+ KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t )
5136+ KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t )
50815137#endif
50825138
// Alias registrations for the names bfloat16x and bf16x, compiled only when
// KERNEL_FLOAT_BF16_AVAILABLE is set. The removed (`-`) lines pass
// `__bfloat16` as the underlying type, whereas the removed BF16/FP8 lines
// elsewhere in this diff spell the type `__nv_bfloat16`; the added (`+`)
// lines pass `bfloat16_t`.
50835139#if KERNEL_FLOAT_BF16_AVAILABLE
5084- KERNEL_FLOAT_TYPE_ALIAS (bfloat16x, __bfloat16 )
5085- KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16 )
5140+ KERNEL_FLOAT_TYPE_ALIAS (bfloat16x, bfloat16_t )
5141+ KERNEL_FLOAT_TYPE_ALIAS(bf16x, bfloat16_t )
50865142#endif
50875143
50885144#if KERNEL_FLOAT_BF8_AVAILABLE
0 commit comments