From f991a792b5ee8ba557a1df7e576009d79261e0ab Mon Sep 17 00:00:00 2001
From: Yuhong Guo
Date: Tue, 10 Dec 2024 16:12:54 +0800
Subject: [PATCH] Support bf16 for quant, layernorm and gemm

Signed-off-by: Yuhong Guo
---
 kernels/csrc/attention/attention_generic.cuh |  65 +++++++
 kernels/csrc/attention/dtype_fp8.cuh         |  41 +++++
 kernels/csrc/fused_kernels.cu                |  49 ++---
 kernels/csrc/layernorm_kernels.cu            |  33 ++--
 kernels/csrc/qgemm/w4a8_per_chn/gemm_cuda.cu | 134 +++++++++--
 .../csrc/qgemm/w4a8_per_group/gemm_cuda.cu   | 125 ++++++++---
 kernels/csrc/qgemm/w8a8/w8a8_gemm_cuda.cu    | 168 ++++++++++++----
 kernels/csrc/utils.cuh                       |  88 ++++++++-
 8 files changed, 521 insertions(+), 182 deletions(-)
 create mode 100644 kernels/csrc/attention/attention_generic.cuh
 create mode 100644 kernels/csrc/attention/dtype_fp8.cuh

diff --git a/kernels/csrc/attention/attention_generic.cuh b/kernels/csrc/attention/attention_generic.cuh
new file mode 100644
index 0000000..62409c0
--- /dev/null
+++ b/kernels/csrc/attention/attention_generic.cuh
@@ -0,0 +1,65 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+template <typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template <typename T>
+struct FloatVec {};
+
+// Template vector operations.
+template +inline __device__ Acc mul(A a, B b); + +template +inline __device__ float sum(T v); + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +} // namespace vllm diff --git a/kernels/csrc/attention/dtype_fp8.cuh b/kernels/csrc/attention/dtype_fp8.cuh new file mode 100644 index 0000000..e714e32 --- /dev/null +++ b/kernels/csrc/attention/dtype_fp8.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include "attention_generic.cuh" + +#include +#ifdef ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 + +namespace vllm { + +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache +template <> +struct Vec { + using Type = uint8_t; +}; + +template <> +struct Vec { + using Type = uint16_t; +}; + +template <> +struct Vec { + using Type = uint32_t; +}; + +template <> +struct Vec { + using Type = uint2; +}; + +} // namespace vllm diff --git a/kernels/csrc/fused_kernels.cu b/kernels/csrc/fused_kernels.cu index 4f5b810..1879499 100644 --- a/kernels/csrc/fused_kernels.cu +++ b/kernels/csrc/fused_kernels.cu @@ -27,11 +27,11 @@ __global__ void dequant_add_residual_kernel(const int32_t *__restrict__ input, for (int i = tid; i < hidden_size; i += blockDim.x) { if constexpr (use_per_token_dequant) { output[token_idx * hidden_size + i] = - (T)((((float)input[token_idx * hidden_size + i]) * __half2float(scale[token_idx])) + + (T)((((float)input[token_idx * hidden_size + i]) * to_float(scale[token_idx])) + (float)residual[token_idx * hidden_size + i]); } else { output[token_idx * hidden_size + i] = - (T)((((float)input[token_idx * hidden_size + i]) * __half2float(scale)) + + (T)((((float)input[token_idx * hidden_size + i]) * to_float(scale)) + (float)residual[token_idx * hidden_size + i]); } } @@ -39,13 +39,13 @@ __global__ void dequant_add_residual_kernel(const int32_t *__restrict__ input, template __global__ void dequant_kernel(const int32_t *__restrict__ input, - T *__restrict__ output, const at::Half scale, int m, + T *__restrict__ output, half scale, int m, int hidden_size, int input_stride, int out_stride) { const int tid = threadIdx.x; const int token_idx = blockIdx.x; for (int i = tid; i < hidden_size; i += blockDim.x) { output[token_idx * out_stride + i] = - (T)(((float)input[token_idx * input_stride + i]) * __half2float(scale)); + (T)(((float)input[token_idx * input_stride + i]) * to_float(scale)); } } @@ -71,7 +71,7 @@ __global__ void quant_kernel(const T *__restrict__ input, const float block_amax_val = blockReduceMax(amax_val); if (tid == 0) { s_amax = block_amax_val; - scale[token_idx] = __float2half_rn(block_amax_val / 127.0f); + scale[token_idx] = from_float(block_amax_val / 127.0f); } __syncthreads(); @@ -83,7 +83,7 @@ __global__ void quant_kernel(const T *__restrict__ input, } else { for (int i = tid; i < hidden_size; i += blockDim.x) { output[token_idx * hidden_size + i] = - float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / __half2float(scale)); + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / to_float(scale)); } } } @@ -91,10 +91,11 @@ __global__ void quant_kernel(const T *__restrict__ input, template 
__global__ void quant_kernel_fuse_sum(const T *__restrict__ input, - int8_t *__restrict__ output, - scale_type input_sum, - scale_type scale, - int num_tokens, int hidden_size) { + int8_t *__restrict__ output, + scale_type input_sum, + scale_type scale, + int num_tokens, + int hidden_size) { // TODO: get the sum here. const int tid = threadIdx.x; const int token_idx = blockIdx.x; @@ -118,8 +119,8 @@ __global__ void quant_kernel_fuse_sum(const T *__restrict__ input, const float block_sum_val = blockReduceSum(sum_val); if (tid == 0) { s_amax = block_amax_val; - scale[token_idx] = __float2half_rn(block_amax_val / 127.0f); - input_sum[token_idx] = __float2half_rn(block_sum_val); + scale[token_idx] = from_float(block_amax_val / 127.0f); + input_sum[token_idx] = from_float(block_sum_val); } __syncthreads(); @@ -131,7 +132,7 @@ __global__ void quant_kernel_fuse_sum(const T *__restrict__ input, } else { for (int i = tid; i < hidden_size; i += blockDim.x) { output[token_idx_mul_hidden_size + i] = - float_to_int8_rn(((float)input[token_idx_mul_hidden_size + i]) / __half2float(scale)); + float_to_int8_rn(((float)input[token_idx_mul_hidden_size + i]) / to_float(scale)); } } } @@ -205,8 +206,10 @@ void invoke_quant(torch::Tensor &out, // [..., hidden_size] dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] { - vllm::quant_kernel<<>>( - input.data_ptr(), out.data_ptr(), scale, num_tokens, hidden_size); + // It look like this function is never called. We hard code the template to half since input argument scale is at::Half. + // using T = typename FloatTypeConverter::Type; + vllm::quant_kernel<<>>( + reinterpret_cast(input.data_ptr()), out.data_ptr(), scale, num_tokens, hidden_size); }); } @@ -221,9 +224,10 @@ void invoke_quant(torch::Tensor &out, // [..., hidden_size] dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] { - vllm::quant_kernel<<>>( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), num_tokens, hidden_size); + using T = typename FloatTypeConverter::Type; + vllm::quant_kernel<<>>( + reinterpret_cast(input.data_ptr()), out.data_ptr(), + reinterpret_cast(scale.data_ptr()), num_tokens, hidden_size); }); } @@ -259,8 +263,9 @@ void invoke_quant_fuse_sum(torch::Tensor &out, // [..., hidden_size] dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel_fuse_sum", [&] { - vllm::quant_kernel_fuse_sum<<>>( - input.data_ptr(), out.data_ptr(), input_sum.data_ptr(), - scale.data_ptr(), num_tokens, hidden_size); + using T = typename FloatTypeConverter::Type; + vllm::quant_kernel_fuse_sum<<>>( + reinterpret_cast(input.data_ptr()), out.data_ptr(), reinterpret_cast(input_sum.data_ptr()), + reinterpret_cast(scale.data_ptr()), num_tokens, hidden_size); }); -} \ No newline at end of file +} diff --git a/kernels/csrc/layernorm_kernels.cu b/kernels/csrc/layernorm_kernels.cu index acd282d..fa14914 100644 --- a/kernels/csrc/layernorm_kernels.cu +++ b/kernels/csrc/layernorm_kernels.cu @@ -10,6 +10,7 @@ #include #include "dispatch_utils.h" + #include "utils.cuh" #include "reduction_utils.cuh" @@ -134,7 +135,7 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr; 
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; const float_packed_t scale_orig_quant - = cuda_cast(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f); + = cuda_cast(with_per_tensor_scaling ? to_float(*scale_orig_quant_per_tensor) : 0.0f); T_scalar amax = 1e-6f; for (int i = tidx; i < n_elems; i += blockDim.x) @@ -270,9 +271,11 @@ __global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr; const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; const float_packed_t scale_orig_quant - = cuda_cast(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f); + = cuda_cast(with_per_tensor_scaling ? to_float(*scale_orig_quant_per_tensor) : 0.0f); T_scalar amax = 1e-6f; - T_scalar sum = 0.0f; + // It is better to use float for sum. + // T_scalar sum = 0.0f; + float sum = 0.0f; for (int i = tidx; i < n_elems; i += blockDim.x) { @@ -303,7 +306,7 @@ __global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const if (with_per_token_scaling) { float abs_max_f = blockAllReduceMax(cuda_cast(amax)); - float sum_f = blockAllReduceSum(cuda_cast(sum)); + float sum_f = blockAllReduceSum(sum); const float dynamic_per_token_scale = 127.f / abs_max_f; for (int i = tidx; i < n_elems; i += blockDim.x) { @@ -316,11 +319,13 @@ __global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const reinterpret_cast(normed_output_quant)[index] = cuda_cast(val_f * cuda_cast(dynamic_per_token_scale)); + // For debug: output unquantized result. + // reinterpret_cast(normed_output_quant)[index] = cuda_cast(val_f); } if (tidx == 0) { scale_orig_quant_per_token[bidx] = abs_max_f / 127.f; - input_sum[bidx] = sum_f; + input_sum[bidx] = from_float(sum_f); } } } @@ -441,10 +446,10 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size] using T = typename FloatTypeConverter::Type; if (use_per_token_quant) { // per-token - vllm::generalLayerNorm<<>>( + vllm::generalLayerNorm<<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast(weight.data_ptr()), nullptr, - nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr(), + nullptr, epsilon, num_tokens, hidden_size, nullptr, reinterpret_cast(scaling.data_ptr()), out.data_ptr(), false ); // input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale @@ -453,10 +458,10 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size] // weight.data_ptr(), epsilon, num_tokens, hidden_size); } else { // per-tensor - vllm::generalLayerNorm<<>>( + vllm::generalLayerNorm<<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast(weight.data_ptr()), nullptr, - nullptr, epsilon, num_tokens, hidden_size, scaling.data_ptr(), nullptr, + nullptr, epsilon, num_tokens, hidden_size, reinterpret_cast(scaling.data_ptr()), nullptr, out.data_ptr(), false ); } @@ -481,11 +486,11 @@ void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size] using T = typename FloatTypeConverter::Type; if (use_per_token_quant) { // per-token - vllm::generalLayerNorm_fuse_sum<<>>( + vllm::generalLayerNorm_fuse_sum<<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast(weight.data_ptr()), nullptr, - nullptr, epsilon, num_tokens, hidden_size, input_sum.data_ptr(), nullptr, scaling.data_ptr(), - out.data_ptr(), false + nullptr, epsilon, num_tokens, hidden_size, reinterpret_cast(input_sum.data_ptr()), + nullptr, 
reinterpret_cast(scaling.data_ptr()), out.data_ptr(), false ); // input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale // normed_output_quant, use_shmem @@ -497,10 +502,10 @@ void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size] // Not implemented per-tensor input_sum assert(false); - vllm::generalLayerNorm_fuse_sum<<>>( + vllm::generalLayerNorm_fuse_sum<<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast(weight.data_ptr()), nullptr, - nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr(), nullptr, + nullptr, epsilon, num_tokens, hidden_size, nullptr, reinterpret_cast(scaling.data_ptr()), nullptr, out.data_ptr(), false ); } diff --git a/kernels/csrc/qgemm/w4a8_per_chn/gemm_cuda.cu b/kernels/csrc/qgemm/w4a8_per_chn/gemm_cuda.cu index eb831fe..500d006 100644 --- a/kernels/csrc/qgemm/w4a8_per_chn/gemm_cuda.cu +++ b/kernels/csrc/qgemm/w4a8_per_chn/gemm_cuda.cu @@ -11,6 +11,9 @@ #include #include +#include "../../dispatch_utils.h" +#include "../../utils.cuh" + #define OP_M 16 #define OP_N 8 #define OP_K 32 @@ -37,7 +40,7 @@ printf("This kernel requires %d Bytes of shared memory, which exceeds " \ "device limit.\n", \ kSmemByteSize); \ - return ; \ + return ; \ } \ int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \ int num_blocks_n = num_out_channels / CTA_N / 1; \ @@ -47,11 +50,11 @@ (num_blocks_m + tile_shift - 1) / tile_shift); \ dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ auto kernel_func = \ - dense_kernel0; \ + dense_kernel0; \ cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, \ kSmemByteSize); \ - kernel_func<<>>( \ - in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats, num_in_feats, num_out_channels, \ + kernel_func<<>>( \ + in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats, num_in_feats, num_out_channels, \ num_in_channels); template @@ -127,8 +130,8 @@ cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) "n"(cp_size)); } -__device__ __inline__ void mma_m16n8k32(void *C_warp, void *A_shared_warp, - void *B_shared_warp) +__device__ __inline__ void mma_m16n8k32(void *C_warp, const void *A_shared_warp, + const void *B_shared_warp) { __asm__ __volatile__( "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32" @@ -148,7 +151,7 @@ __device__ __inline__ void mma_m16n8k32(void *C_warp, void *A_shared_warp, template __device__ __inline__ void -global_to_share_one_stage_A(int8_t *src, int8_t *dst, int global_ncols, +global_to_share_one_stage_A(const int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask, bool *preds) @@ -160,7 +163,7 @@ global_to_share_one_stage_A(int8_t *src, int8_t *dst, int global_ncols, constexpr int threads_per_row = CTA_K / PACK_SIZE; constexpr int kSmemCol = CTA_K + SMEM_PAD_A; int8_t *dst_hoisted = dst; - int8_t *src_hoisted = src + global_iter_k * CTA_K; + const int8_t *src_hoisted = src + global_iter_k * CTA_K; if (mask) { @@ -194,7 +197,7 @@ global_to_share_one_stage_A(int8_t *src, int8_t *dst, int global_ncols, template __device__ __inline__ void -global_to_share_one_stage_B(int8_t *src, int8_t *dst, int global_ncols, +global_to_share_one_stage_B(const int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) { @@ -204,7 +207,7 @@ global_to_share_one_stage_B(int8_t *src, int8_t *dst, int global_ncols, constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row; 
constexpr int kSmemCol = CTA_K; int8_t *dst_hoisted = dst; - int8_t *src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE; + const int8_t *src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE; #pragma unroll for (int global_iter = 0; global_iter < total_global_iters; ++global_iter) @@ -228,7 +231,7 @@ global_to_share_one_stage_B(int8_t *src, int8_t *dst, int global_ncols, template -__device__ __inline__ void global_to_share_one_stage_zeros(int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) +__device__ __inline__ void global_to_share_one_stage_zeros(const int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) { constexpr int threads_needed = CTA_N / PACK_SIZE / 1; constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE; @@ -256,7 +259,7 @@ __device__ __inline__ void global_to_share_one_stage_zeros(int8_t *src, int8_t * template __device__ __inline__ void -share_to_reg_one_stage_A(int8_t *src, int8_t *dst, int warp_offset_m, +share_to_reg_one_stage_A(const int8_t *src, int8_t *dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters) { constexpr int kSmemCol = CTA_K + SMEM_PAD_A; @@ -301,11 +304,11 @@ share_to_reg_one_stage_B(int8_t *src, int8_t *dst, int8_t *zeros, int8_t *scales } template -__global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, - half2 *__restrict__ wscales, half *__restrict__ ascales, - half2 *__restrict__ w_szs, half *__restrict__ a_ssums, - half *__restrict__ C, int M, int64_t N, int64_t K) + int STAGES, int G, typename T, typename T2> +__global__ void dense_kernel0(const int8_t *__restrict__ A, const int8_t *__restrict__ B, + const T2 *__restrict__ wscales, const T *__restrict__ ascales, + const T2 *__restrict__ w_szs, const T *__restrict__ a_ssums, + T *__restrict__ C, int M, int64_t N, int64_t K) { constexpr int SPLITK = 1; constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N; @@ -383,9 +386,9 @@ __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE + (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE; - int8_t *A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + + const int8_t *A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE; - int8_t *B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + + const int8_t *B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE + (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE; @@ -578,14 +581,14 @@ __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, if (row_wb < M) { int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2); - float2 wscale = __half22float2(*(wscales + col_wb / 2)); - float2 w_sz = __half22float2(*(w_szs + col_wb / 2)); - float ascale = __half2float(ascales[row_wb]); - float a_ssum = __half2float(a_ssums[row_wb]); + float2 wscale = to_float2(*(wscales + col_wb / 2)); + float2 w_sz = to_float2(*(w_szs + col_wb / 2)); + float ascale = to_float(ascales[row_wb]); + float a_ssum = to_float(a_ssums[row_wb]); float2 psums = make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1])); psums.x = psums.x * wscale.x * ascale - w_sz.x * a_ssum; psums.y = psums.y * wscale.y * ascale - w_sz.y * a_ssum; - *reinterpret_cast(C + 
row_wb * N + col_wb) = __float22half2_rn(psums); + *reinterpret_cast(C + row_wb * N + col_wb) = from_float2(psums); } }; } @@ -593,25 +596,24 @@ __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, } } -void gemm_forward_cuda(torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _wscales, - torch::Tensor _ascales, - torch::Tensor _w_szs, - torch::Tensor _a_ssums, - torch::Tensor _out_feats) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto w_szs = reinterpret_cast(_w_szs.data_ptr()); - auto a_ssums = reinterpret_cast(_a_ssums.data_ptr()); - auto wscales = reinterpret_cast(_wscales.data_ptr()); - auto ascales = reinterpret_cast(_ascales.data_ptr()); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); +template +void gemm_w4a8_per_chn(int num_in_feats, + int num_in_channels, + int num_out_feats, + int num_out_channels, + const int8_t* in_feats_ptr, + const int8_t* kernel_ptr, + const T* wscales_ptr, + const T* ascales, + const T* w_szs_ptr, + const T* a_ssums, + T* out_feats, + cudaStream_t stream) { + auto in_feats = reinterpret_cast(in_feats_ptr); + auto kernel = reinterpret_cast(kernel_ptr); + using T2 = typename packed_as::type; + auto wscales = reinterpret_cast(wscales_ptr); + auto w_szs = reinterpret_cast(w_szs_ptr); constexpr int G = 128; @@ -650,3 +652,49 @@ void gemm_forward_cuda(torch::Tensor _in_feats, } return ; } + +void gemm_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _wscales, + torch::Tensor _ascales, + torch::Tensor _w_szs, + torch::Tensor _a_ssums, + torch::Tensor _out_feats) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + bool result = false; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (_wscales.scalar_type() == at::kHalf) { + using scalar_t = at::Half; + using T = typename FloatTypeConverter::Type; + auto w_szs = reinterpret_cast(_w_szs.data_ptr()); + auto a_ssums = reinterpret_cast(_a_ssums.data_ptr()); + auto wscales = reinterpret_cast(_wscales.data_ptr()); + auto ascales = reinterpret_cast(_ascales.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + gemm_w4a8_per_chn( + num_in_feats, num_in_channels, num_out_feats, num_out_channels, + in_feats, kernel, wscales, ascales, + w_szs, a_ssums, out_feats, stream); + } else if (_wscales.scalar_type() == at::kBFloat16) { +#ifdef ENABLE_BF16 + using scalar_t = at::BFloat16; + using T = typename FloatTypeConverter::Type; + auto w_szs = reinterpret_cast(_w_szs.data_ptr()); + auto a_ssums = reinterpret_cast(_a_ssums.data_ptr()); + auto wscales = reinterpret_cast(_wscales.data_ptr()); + auto ascales = reinterpret_cast(_ascales.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + gemm_w4a8_per_chn( + num_in_feats, num_in_channels, num_out_feats, num_out_channels, + in_feats, kernel, wscales, ascales, + w_szs, a_ssums, out_feats, stream); +#endif + } +} diff --git a/kernels/csrc/qgemm/w4a8_per_group/gemm_cuda.cu b/kernels/csrc/qgemm/w4a8_per_group/gemm_cuda.cu index 24d6eaa..532a057 100644 --- 
a/kernels/csrc/qgemm/w4a8_per_group/gemm_cuda.cu +++ b/kernels/csrc/qgemm/w4a8_per_group/gemm_cuda.cu @@ -11,6 +11,9 @@ #include #include +#include "../../dispatch_utils.h" +#include "../../utils.cuh" + #define OP_M 16 #define OP_N 8 #define OP_K 32 @@ -47,10 +50,10 @@ (num_blocks_m + tile_shift - 1) / tile_shift); \ dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ auto kernel_func = \ - dense_kernel0; \ + dense_kernel0; \ cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, \ kSmemByteSize); \ - kernel_func<<>>( \ + kernel_func<<>>( \ in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats, num_in_feats, num_out_channels, \ num_in_channels); @@ -148,7 +151,7 @@ __device__ __inline__ void mma_m16n8k32(void *C_warp, void *A_shared_warp, template __device__ __inline__ void -global_to_share_one_stage_A(int8_t *src, int8_t *dst, int global_ncols, +global_to_share_one_stage_A(const int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask, bool *preds) @@ -160,7 +163,7 @@ global_to_share_one_stage_A(int8_t *src, int8_t *dst, int global_ncols, constexpr int threads_per_row = CTA_K / PACK_SIZE; constexpr int kSmemCol = CTA_K + SMEM_PAD_A; int8_t *dst_hoisted = dst; - int8_t *src_hoisted = src + global_iter_k * CTA_K; + const int8_t *src_hoisted = src + global_iter_k * CTA_K; if (mask) { @@ -190,7 +193,7 @@ global_to_share_one_stage_A(int8_t *src, int8_t *dst, int global_ncols, template __device__ __inline__ void -global_to_share_one_stage_B(int8_t *src, int8_t *dst, int global_ncols, +global_to_share_one_stage_B(const int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) { @@ -200,7 +203,7 @@ global_to_share_one_stage_B(int8_t *src, int8_t *dst, int global_ncols, constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row; constexpr int kSmemCol = CTA_K; int8_t *dst_hoisted = dst; - int8_t *src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE; + const int8_t *src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE; #pragma unroll for (int global_iter = 0; global_iter < total_global_iters; ++global_iter) @@ -223,7 +226,7 @@ global_to_share_one_stage_B(int8_t *src, int8_t *dst, int global_ncols, } template -__device__ __inline__ void global_to_share_one_stage_zeros(int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) +__device__ __inline__ void global_to_share_one_stage_zeros(const int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) { constexpr int threads_needed = CTA_N / PACK_SIZE / 1; constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE; @@ -251,7 +254,7 @@ __device__ __inline__ void global_to_share_one_stage_zeros(int8_t *src, int8_t * template __device__ __inline__ void -share_to_reg_one_stage_A(int8_t *src, int8_t *dst, int warp_offset_m, +share_to_reg_one_stage_A(const int8_t *src, int8_t *dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters) { constexpr int kSmemCol = CTA_K + SMEM_PAD_A; @@ -270,7 +273,7 @@ share_to_reg_one_stage_A(int8_t *src, int8_t *dst, int warp_offset_m, template __device__ __inline__ void -share_to_reg_one_stage_B(int8_t *src, int8_t *dst, int8_t *zeros, int8_t *scales_i8, +share_to_reg_one_stage_B(const int8_t *src, int8_t *dst, const int8_t *zeros, const int8_t *scales_i8, int warp_offset_m, int warp_offset_n, int k_0_0, int k_0_1, int shared_iters) { @@ -292,8 +295,8 @@ share_to_reg_one_stage_B(int8_t *src, int8_t *dst, int8_t *zeros, int8_t *scales auto ptr = (uint32_t *)dst + shared_iter * 8; int scales_zeros_offset = warp_offset_n + (threadIdx.x / 4) * 4 + shared_iter * 32; - uint32_t packed_scales = *reinterpret_cast(scales_i8 + scales_zeros_offset); - uint32_t packed_zeros = *reinterpret_cast(zeros + scales_zeros_offset); + uint32_t packed_scales = *reinterpret_cast(scales_i8 + scales_zeros_offset); + uint32_t packed_zeros = *reinterpret_cast(zeros + scales_zeros_offset); uint32_t scale_0 = packed_scales & 0xFF; uint32_t zero_point_0 = __byte_perm(packed_zeros, 0, 0x00000000); @@ -326,11 +329,11 @@ share_to_reg_one_stage_B(int8_t *src, int8_t *dst, int8_t *zeros, int8_t *scales } template -__global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, - int8_t *__restrict__ zeros, int8_t *__restrict__ scales_i8, - half2 *__restrict__ wscales, half *__restrict__ ascales, - half *__restrict__ C, int M, int64_t N, int64_t K) + int STAGES, int G, typename T, typename T2> +__global__ void dense_kernel0(const int8_t *__restrict__ A, const int8_t *__restrict__ B, + const int8_t *__restrict__ zeros, const int8_t *__restrict__ scales_i8, + const T2 *__restrict__ wscales, const T *__restrict__ ascales, + T *__restrict__ C, int M, int64_t N, int64_t K) { constexpr int SPLITK = 1; constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N; @@ -408,9 +411,9 @@ __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE + (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE; - int8_t *A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + + const int8_t *A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE; - int8_t *B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + + const int8_t *B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE + (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE; @@ -614,12 +617,12 @@ __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, if (row_wb < M) { int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2); - float2 wscale = __half22float2(*(wscales + col_wb / 2)); - float ascale = __half2float(ascales[row_wb]); + float2 wscale = to_float2(*(wscales + col_wb / 2)); + float ascale = to_float(ascales[row_wb]); float2 psums = make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1])); psums.x *= wscale.x * ascale; psums.y *= wscale.y * ascale; - *reinterpret_cast(C + row_wb * N + col_wb) = __float22half2_rn(psums); + *reinterpret_cast(C + row_wb * 
N + col_wb) = from_float2(psums); } }; } @@ -627,27 +630,21 @@ __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, } } -void gemm_forward_cuda(torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _zeros, - torch::Tensor _scales_i8, - torch::Tensor _wscales, - torch::Tensor _ascales, - torch::Tensor _out_feats) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - auto scales_i8 = reinterpret_cast(_scales_i8.data_ptr()); - auto wscales = reinterpret_cast(_wscales.data_ptr()); - auto ascales = reinterpret_cast(_ascales.data_ptr()); - auto options = - torch::TensorOptions().dtype(torch::kHalf).device(_in_feats.device()); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); +template +void gemm_w4a8_per_group(int num_in_feats, + int num_in_channels, + int num_out_feats, + int num_out_channels, + const int8_t* in_feats, + const int8_t* kernel, + const int8_t* zeros, + const int8_t* scales_i8, + const T* wscales_ptr, + const T* ascales, + T* out_feats, + cudaStream_t stream = nullptr) { + using T2 = typename packed_as::type; + auto wscales = reinterpret_cast(wscales_ptr); constexpr int G = 128; @@ -700,3 +697,47 @@ void gemm_forward_cuda(torch::Tensor _in_feats, } return ; } + + +void gemm_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _zeros, + torch::Tensor _scales_i8, + torch::Tensor _wscales, + torch::Tensor _ascales, + torch::Tensor _out_feats) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + int8_t *in_feats = reinterpret_cast(_in_feats.data_ptr()); + int8_t *kernel = reinterpret_cast(_kernel.data_ptr()); + int8_t *zeros = reinterpret_cast(_zeros.data_ptr()); + int8_t *scales_i8 = reinterpret_cast(_scales_i8.data_ptr()); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + bool result = false; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (_wscales.scalar_type() == at::kHalf) { + using scalar_t = at::Half; + using T = typename FloatTypeConverter::Type; + auto wscales = reinterpret_cast(_wscales.data_ptr()); + auto ascales = reinterpret_cast(_ascales.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + gemm_w4a8_per_group( + num_in_feats, num_in_channels, num_out_feats, num_out_channels, + in_feats, kernel, zeros, scales_i8, + wscales, ascales, out_feats, stream); + } else if (_wscales.scalar_type() == at::kBFloat16) { +#ifdef ENABLE_BF16 + using scalar_t = at::BFloat16; + using T = typename FloatTypeConverter::Type; + auto wscales = reinterpret_cast(_wscales.data_ptr()); + auto ascales = reinterpret_cast(_ascales.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + gemm_w4a8_per_group( + num_in_feats, num_in_channels, num_out_feats, num_out_channels, + in_feats, kernel, zeros, scales_i8, + wscales, ascales, out_feats, stream); +#endif + } +} diff --git a/kernels/csrc/qgemm/w8a8/w8a8_gemm_cuda.cu b/kernels/csrc/qgemm/w8a8/w8a8_gemm_cuda.cu index de47acd..8cabc6f 100644 --- a/kernels/csrc/qgemm/w8a8/w8a8_gemm_cuda.cu +++ b/kernels/csrc/qgemm/w8a8/w8a8_gemm_cuda.cu @@ -11,6 +11,9 @@ #include #include +#include "../../dispatch_utils.h" +#include "../../utils.cuh" + #define OP_M 16 #define OP_N 
8 #define OP_K 32 @@ -26,31 +29,31 @@ #else #define L2_CACHEHINT(size) #endif -#define KERNEL_LAUNCH_CODE \ +#define KERNEL_LAUNCH_CODE \ constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \ - constexpr int kSmemByteSize = \ - (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B)) * STAGES * \ - sizeof(int8_t); \ - if (kSmemByteSize >= 99 * 1024) \ - { \ - printf("This kernel requires %d Bytes of shared memory, which exceeds " \ - "device limit.\n", \ - kSmemByteSize); \ - return ; \ - } \ - int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \ - int num_blocks_n = num_out_channels / CTA_N / 1; \ - const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \ - const int tile_shift = 1 << log_tile; \ - dim3 num_blocks(num_blocks_n *tile_shift, \ - (num_blocks_m + tile_shift - 1) / tile_shift); \ - dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ - auto kernel_func = \ - dense_kernel0; \ - cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, \ - kSmemByteSize); \ - kernel_func<<>>( \ - in_feats, kernel, wscales, ascales, out_feats, num_in_feats, num_out_channels, \ + constexpr int kSmemByteSize = \ + (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B)) * STAGES * \ + sizeof(int8_t); \ + if (kSmemByteSize >= 99 * 1024) \ + { \ + printf("This kernel requires %d Bytes of shared memory, which exceeds " \ + "device limit.\n", \ + kSmemByteSize); \ + return ; \ + } \ + int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \ + int num_blocks_n = num_out_channels / CTA_N / 1; \ + const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \ + const int tile_shift = 1 << log_tile; \ + dim3 num_blocks(num_blocks_n *tile_shift, \ + (num_blocks_m + tile_shift - 1) / tile_shift); \ + dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ + auto kernel_func = \ + dense_kernel0; \ + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, \ + kSmemByteSize); \ + kernel_func<<>>( \ + in_feats, kernel, wscales, ascales, out_feats, num_in_feats, num_out_channels,\ num_in_channels); template @@ -147,7 +150,7 @@ __device__ __inline__ void mma_m16n8k32(void *C_warp, void *A_shared_warp, template __device__ __inline__ void -global_to_share_one_stage_A(int8_t *src, int8_t *dst, int global_ncols, +global_to_share_one_stage_A(const int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask, bool *preds) @@ -158,8 +161,8 @@ global_to_share_one_stage_A(int8_t *src, int8_t *dst, int global_ncols, constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K; constexpr int threads_per_row = CTA_K / PACK_SIZE; constexpr int kSmemCol = CTA_K + SMEM_PAD_A; - int8_t *dst_hoisted = dst; - int8_t *src_hoisted = src + global_iter_k * CTA_K; + const int8_t *dst_hoisted = dst; + const int8_t *src_hoisted = src + global_iter_k * CTA_K; if (mask) { @@ -190,7 +193,7 @@ global_to_share_one_stage_A(int8_t *src, int8_t *dst, int global_ncols, template __device__ __inline__ void -global_to_share_one_stage_B(int8_t *src, int8_t *dst, int global_ncols, +global_to_share_one_stage_B(const int8_t *src, int8_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) { @@ -201,7 +204,7 @@ global_to_share_one_stage_B(int8_t *src, int8_t *dst, int global_ncols, constexpr int threads_per_row = CTA_K / PACK_SIZE; constexpr int kSmemCol = CTA_K + SMEM_PAD_B; int8_t *dst_hoisted = dst; - int8_t 
*src_hoisted = src + global_iter_k * CTA_K; + const int8_t *src_hoisted = src + global_iter_k * CTA_K; #pragma unroll for (int _global_iter = 0; _global_iter < partial_global_iters; @@ -265,10 +268,10 @@ share_to_reg_one_stage_B(int8_t *src, int8_t *dst, int warp_offset_m, } template -__global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, - half2 *__restrict__ wscales, half *__restrict__ ascales, - half *__restrict__ C, int M, int N, int K) + int STAGES, typename T, typename T2> +__global__ void dense_kernel0(const int8_t *__restrict__ A, const int8_t *__restrict__ B, + const T2 *__restrict__ wscales, const T *__restrict__ ascales, + T *__restrict__ C, int M, int N, int K) { constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N; constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K; @@ -336,9 +339,9 @@ __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, A_hoisted_col_swizzled * PACK_SIZE; int8_t *B_shared_hoisted = B_shared + B_hoisted_row * kSmemPadKB + B_hoisted_col_swizzled * PACK_SIZE; - int8_t *A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + + const int8_t *A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE; - int8_t *B_hoisted = B + cta_offset_n * K + B_hoisted_row * K + + const int8_t *B_hoisted = B + cta_offset_n * K + B_hoisted_row * K + B_hoisted_col * PACK_SIZE; bool A_g2s_preds[A_total_global_iters]; #pragma unroll @@ -516,12 +519,12 @@ __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, int row_wb = row_wb_1 + (local_id % 4) / 2 * 8; if (row_wb < M){ int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2); - float2 wscale = __half22float2(*(wscales + col_wb / 2)); - float ascale = __half2float(ascales[row_wb]); + float2 wscale = to_float2(*(wscales + col_wb / 2)); + float ascale = to_float(ascales[row_wb]); float2 psums = make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1])); psums.x *= wscale.x * ascale; psums.y *= wscale.y * ascale; - *reinterpret_cast(C + row_wb * N + col_wb) = __float22half2_rn(psums); + *reinterpret_cast(C + row_wb * N + col_wb) = from_float2(psums); } }; } @@ -529,27 +532,23 @@ __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, } } -void w8a8_gemm_forward_cuda(torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _wscales, - torch::Tensor _ascales, - torch::Tensor _out_feats) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto wscales = reinterpret_cast(_wscales.data_ptr()); - auto ascales = reinterpret_cast(_ascales.data_ptr()); - // auto options = - // torch::TensorOptions().dtype(torch::kFloat16).device(_in_feats.device()); - // at::Tensor _out_feats = - // torch::empty({num_in_feats, _kernel.size(0)}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); +template +void gemm_w8a8_cuda(int num_in_feats, + int num_in_channels, + int num_out_feats, + int num_out_channels, + const int8_t* in_feats_ptr, + const int8_t* kernel_ptr, + const T* wscales_ptr, + const T* ascales_ptr, + T* out_ptr, + cudaStream_t stream) { + using T2 = typename packed_as::type; + auto in_feats = reinterpret_cast(in_feats_ptr); + auto kernel = reinterpret_cast(kernel_ptr); + auto wscales = reinterpret_cast(wscales_ptr); 
+ auto ascales = reinterpret_cast(ascales_ptr); + auto out_feats = reinterpret_cast(out_ptr); if (num_out_feats > 128) { @@ -576,3 +575,54 @@ void w8a8_gemm_forward_cuda(torch::Tensor _in_feats, return ; } +void w8a8_gemm_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _wscales, + torch::Tensor _ascales, + torch::Tensor _out_feats) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + int8_t* in_feats = reinterpret_cast(_in_feats.data_ptr()); + int8_t* kernel = reinterpret_cast(_kernel.data_ptr()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (_wscales.scalar_type() == at::kHalf) { + using scalar_t = at::Half; + using T = typename FloatTypeConverter::Type; + auto wscales = reinterpret_cast(_wscales.data_ptr()); + auto ascales = reinterpret_cast(_ascales.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + return gemm_w8a8_cuda(num_in_feats, + num_in_channels, + num_out_feats, + num_out_channels, + in_feats, + kernel, + wscales, + ascales, + out_feats, + stream); + } else if (_wscales.scalar_type() == at::kBFloat16) { +#ifdef ENABLE_BF16 + using scalar_t = at::BFloat16; + using T = typename FloatTypeConverter::Type; + auto wscales = reinterpret_cast(_wscales.data_ptr()); + auto ascales = reinterpret_cast(_ascales.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + return gemm_w8a8_cuda(num_in_feats, + num_in_channels, + num_out_feats, + num_out_channels, + in_feats, + kernel, + wscales, + ascales, + out_feats, + stream); +#endif + } +} diff --git a/kernels/csrc/utils.cuh b/kernels/csrc/utils.cuh index cd4d45d..7413959 100644 --- a/kernels/csrc/utils.cuh +++ b/kernels/csrc/utils.cuh @@ -6,7 +6,17 @@ #include #include #include -#include +#include +#include +#include +#include +#ifdef ENABLE_FP8 +#include "attention/dtype_fp8.cuh" +#endif +#if ENABLE_BF16 +#include +#endif + template struct FloatTypeConverter @@ -466,4 +476,78 @@ __device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) } #endif -#endif // ENABLE_FP16 \ No newline at end of file +#endif // ENABLE_FP16 + + +template +__device__ inline float to_float(T t); + +template <> +__device__ inline float to_float(half t) { + return __half2float(t); +} + +template <> +__device__ inline float to_float(float t) { + return t; +} + +template +__device__ inline T from_float(float t);// { +// assert(false); +// return {}; +// } + +template <> +__device__ inline half from_float(float t) { + return __float2half_rn(t); +} + +template <> +__device__ inline float from_float(float t) { + return t; +} + +template +__device__ inline float2 to_float2(T t); + +template <> +__device__ inline float2 to_float2(half2 t) { + return __half22float2(t); +} + +template +__device__ inline T from_float2(float2 t); + +template <> +__device__ inline half2 from_float2(float2 t) { + return __float22half2_rn(t); +} + +#if ENABLE_BF16 +template <> +__device__ inline float to_float(nv_bfloat16 t) { + return __bfloat162float(t); +} + +template <> +__device__ inline nv_bfloat16 from_float(float t) { + return __float2bfloat16_rn(t); +} + +template <> +__device__ inline float2 to_float2(nv_bfloat162 t) { + return __bfloat1622float2(t); +} + +template <> +__device__ inline nv_bfloat162 from_float2(float2 t) { + return __float22bfloat162_rn(t); +} +#endif + + +template <> +__device__ inline float to_float(at::Half t) { + return 
__half2float(*reinterpret_cast<const half *>(&t));
+}
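
The heart of the change is the dtype-generic conversion helpers added to kernels/csrc/utils.cuh: to_float / from_float and to_float2 / from_float2 replace the hard-coded __half2float, __float2half_rn, __half22float2 and __float22half2_rn calls in the quant, layernorm and GEMM kernels. A minimal sketch of how a kernel expression is written against them (not part of the patch; dequant_one and the flat include path are illustrative):

#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "utils.cuh"  // to_float(T) and from_float<T>(float) as added by this patch

// T is half, or __nv_bfloat16 when the build defines ENABLE_BF16.
template <typename T>
__device__ __forceinline__ T dequant_one(int32_t q, T scale) {
  // int32 accumulator -> float arithmetic -> back to the activation dtype,
  // mirroring the expressions in dequant_kernel / dequant_add_residual_kernel.
  float v = static_cast<float>(q) * to_float(scale);
  return from_float<T>(v);
}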
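
The GEMM launchers derive a packed scale type with `using T2 = typename packed_as<T, 2>::type;` so that wscales can be read two output channels per load (half -> half2, bf16 -> __nv_bfloat162). packed_as itself is expected to come from utils.cuh and is not shown in the diff; the stand-in below only documents the mapping the new code relies on:

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Stand-in for the packed_as trait assumed by gemm_w8a8_cuda / gemm_w4a8_per_chn /
// gemm_w4a8_per_group; only the <T, 2> cases used there are sketched.
template <typename T, int N>
struct packed_as_sketch;

template <>
struct packed_as_sketch<half, 2> {
  using type = half2;
};

#ifdef ENABLE_BF16
template <>
struct packed_as_sketch<__nv_bfloat16, 2> {
  using type = __nv_bfloat162;
};
#endif

// Usage, as in the launchers:  using T2 = typename packed_as_sketch<T, 2>::type;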
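
On the host side, every rewritten entry point (gemm_forward_cuda, w8a8_gemm_forward_cuda) follows the same shape: branch on the scale tensor's scalar_type, map at::Half / at::BFloat16 to the corresponding CUDA type, and call a templated launcher on the current stream. A condensed sketch of that dispatch; launch_typed is a placeholder for the real launchers, and the error checks are an assumption (the patch itself falls through silently):

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename T>
void launch_typed(const T* scales, T* out, cudaStream_t stream);  // stands in for gemm_w8a8_cuda etc.

void dispatch_on_scale_dtype(const torch::Tensor& scales, torch::Tensor& out) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (scales.scalar_type() == at::kHalf) {
    // at::Half and half are layout-compatible, so a reinterpret_cast is enough.
    launch_typed<half>(reinterpret_cast<const half*>(scales.data_ptr()),
                       reinterpret_cast<half*>(out.data_ptr()), stream);
  } else if (scales.scalar_type() == at::kBFloat16) {
#ifdef ENABLE_BF16
    launch_typed<__nv_bfloat16>(
        reinterpret_cast<const __nv_bfloat16*>(scales.data_ptr()),
        reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), stream);
#else
    TORCH_CHECK(false, "built without ENABLE_BF16");
#endif
  } else {
    TORCH_CHECK(false, "unsupported scale dtype: ", scales.scalar_type());
  }
}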
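
quant_kernel_fuse_sum block-reduces two quantities per token: the absolute maximum (stored as amax/127, the per-token scale) and the row sum of the unquantized activations (stored in input_sum). A plain C++ reference of that per-token step, useful when checking kernel output; it assumes float_to_int8_rn behaves as saturating round-to-nearest, and the function name is illustrative:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference for the per-token path of quant_kernel_fuse_sum (up to rounding of
// the stored scale to half/bf16 on the device side).
void quant_fuse_sum_ref(const float* input, int num_tokens, int hidden_size,
                        int8_t* output, float* scale, float* input_sum) {
  for (int t = 0; t < num_tokens; ++t) {
    const float* row = input + static_cast<int64_t>(t) * hidden_size;
    float amax = 1e-6f;  // small floor so an all-zero row stays well defined
    float sum = 0.0f;    // accumulated in float, as the patched kernels now do
    for (int i = 0; i < hidden_size; ++i) {
      amax = std::max(amax, std::fabs(row[i]));
      sum += row[i];
    }
    scale[t] = amax / 127.0f;  // per-token scale, as in scale[token_idx] = from_float(block_amax_val / 127.0f)
    input_sum[t] = sum;        // presumably what the w4a8 per-channel GEMM consumes as a_ssums
    for (int i = 0; i < hidden_size; ++i) {
      float q = std::nearbyint(row[i] / scale[t]);
      output[static_cast<int64_t>(t) * hidden_size + i] =
          static_cast<int8_t>(std::min(127.0f, std::max(-128.0f, q)));
    }
  }
}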
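
In the w4a8 per-channel epilogue the int32 accumulators are written back as psum * wscale * ascale - w_sz * a_ssum: a per-output-channel weight scale times a per-token activation scale, minus a correction that removes the weight zero-point contribution via the per-token activation sum (w_szs is presumably weight_scale * zero_point per channel). The w8a8 epilogue is the same expression without the correction term. A host-side reference of that write-back arithmetic, with illustrative names:

#include <cstdint>

// Write-back math of the w4a8 per-channel GEMM epilogue; drop the last term
// for the w8a8 kernel. The parameter meanings follow the dequantization
// identity
//   sum_k a_q[k] * (w_q[k][n] - z[n]) * s_w[n] * s_a
//     = psum * s_w[n] * s_a  -  (s_w[n] * z[n]) * (s_a * sum_k a_q[k]).
inline float dequant_writeback_ref(int32_t psum,   // raw int8-activation x int4-weight accumulator
                                   float wscale,   // s_w[n]: per-output-channel weight scale
                                   float ascale,   // s_a:    per-token activation scale
                                   float w_sz,     // s_w[n] * z[n]  (presumed meaning of w_szs)
                                   float a_ssum) { // per-token activation sum (input_sum)
  return static_cast<float>(psum) * wscale * ascale - w_sz * a_ssum;
}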