From 980c650b4b485e45198c232c30a069769ae48226 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Mon, 8 Sep 2025 19:51:24 +0800 Subject: [PATCH] Revert "supports dynamic Cfp8 (#3767)" This reverts commit af49b81ffd63484bc57694e43efb429ae6baf4ab. --- custom_ops/gpu_ops/append_attention.cu | 4 +- .../append_attn/append_attention_c8_impl.cuh | 302 ++++------ .../append_attn/append_attention_func.cuh | 199 +------ .../append_attn/append_attention_kernel.h | 5 +- .../decoder_write_cache_with_rope_impl.cuh | 288 ---------- .../decoder_write_cache_with_rope_kernel.cu | 64 +-- .../encoder_write_cache_with_rope_impl.cuh | 520 ++---------------- .../encoder_write_cache_with_rope_kernel.h | 4 +- .../append_attn/gqa_rope_write_cache.cu | 4 +- ...d_attention_c8_bfloat16_bfloat16_kernel.cu | 2 - ...append_attention_c8_bfloat16_fp8_kernel.cu | 2 - ...ppend_attention_c8_bfloat16_int8_kernel.cu | 2 - ...end_attention_c8_float16_float16_kernel.cu | 2 - .../append_attention_c8_float16_fp8_kerne.cu | 2 - .../append_attention_c8_float16_int8_kerne.cu | 2 - custom_ops/gpu_ops/append_attn/utils.cuh | 9 - .../layers/attention/append_attn_backend.py | 27 +- .../layers/quantization/kv_cache.py | 18 +- fastdeploy/worker/gpu_model_runner.py | 13 - tests/layers/test_append_attention.py | 169 +----- 20 files changed, 223 insertions(+), 1415 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 5e4ce35da7..6af601dadc 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -140,8 +140,8 @@ void AppendAttentionKernel( key_cache, value_cache, attn_mask, - cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales, - cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, cache_k_zp, cache_v_zp, out_linear_shifts, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh index b2fe4c6f64..aa2b81b389 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh @@ -32,15 +32,14 @@ template + bool IsFP8=false> __global__ void multi_query_append_attention_c8_kernel( T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] - const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ cache_k_scale, // [num_kv_heads] + const T *__restrict__ cache_v_scale, // [num_kv_heads] const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] const int *__restrict__ seq_lens, @@ -98,30 +97,28 @@ __global__ void multi_query_append_attention_c8_kernel( return; } - T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; - T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2]; - if constexpr (!IsDynamicC8) { - if constexpr (is_scale_channel_wise) { - int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; - const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; - cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; - cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; - cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; - } - scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; - cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; - } - } else { - cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; - cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + T cache_k_scale_reg[num_frags_y * 4]; + T cache_v_scale_reg[num_frags_y * 2]; + if (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; } const uint32_t q_end = @@ -210,13 +207,6 @@ __global__ void multi_query_append_attention_c8_kernel( smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); - T* k_smem_scale = nullptr; - T* v_smem_scale = nullptr; - if constexpr (IsDynamicC8) { - k_smem_scale = reinterpret_cast(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); - v_smem_scale = k_smem_scale + num_frags_z * 16; - } const uint32_t num_iterations = div_up( @@ -298,22 +288,10 @@ __global__ void multi_query_append_attention_c8_kernel( #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { - if constexpr (IsDynamicC8) { - produce_k_dynamic_scale( - k_smem_scale, - cache_k_scale_reg, - block_table_now, - cache_k_scale, - kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } wait_group<1>(); __syncthreads(); // s = qk - compute_qk_c8( + compute_qk_c8( &qo_smem, &q_smem_offset_r, &k_smem, @@ -346,7 +324,6 @@ __global__ void multi_query_append_attention_c8_kernel( s_frag, o_frag, m_frag, d_frag); __syncthreads(); - const int ori_kv_idx_base = kv_idx_base; kv_idx_base += num_frags_z * 16; produce_k_blockwise_c8( - v_smem_scale, - cache_v_scale_reg, - block_table_now, - cache_v_scale, - ori_kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } wait_group<1>(); __syncthreads(); @@ -387,9 +352,7 @@ __global__ void multi_query_append_attention_c8_kernel( BLOCK_SIZE, T, CacheT, - is_scale_channel_wise, - IsFP8, - IsDynamicC8>( + is_scale_channel_wise, IsFP8>( &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); __syncthreads(); @@ -506,15 +469,14 @@ template + bool IsFP8=false> __global__ void multi_query_append_attention_c8_warp1_4_kernel( T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] - const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] const int *__restrict__ seq_lens, @@ -572,30 +534,28 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( if (q_len <= 0) { return; } - T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; - T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2]; - if constexpr (!IsDynamicC8) { - if constexpr (is_scale_channel_wise) { - int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; - const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; - cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; - cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; - cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; - } - scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; - cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; - } - } else { - cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; - cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + T cache_k_scale_reg[num_frags_y * 4]; + T cache_v_scale_reg[num_frags_y * 2]; + if (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; } const uint32_t q_end = min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); @@ -686,13 +646,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); - T* k_smem_scale = nullptr; - T* v_smem_scale = nullptr; - if constexpr (IsDynamicC8) { - k_smem_scale = reinterpret_cast(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); - v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16; - } const uint32_t num_iterations = div_up( CAUSAL @@ -775,23 +728,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { - if constexpr (IsDynamicC8) { - produce_k_dynamic_scale( - k_smem_scale, - cache_k_scale_reg, - block_table_now, - cache_k_scale, - kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } wait_group<1>(); __syncthreads(); // s = qk - compute_qk_c8( + compute_qk_c8( &qo_smem, &q_smem_offset_r, &k_smem, @@ -824,7 +765,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( s_frag, o_frag, m_frag, d_frag); __syncthreads(); - const uint32_t ori_kv_idx_base = kv_idx_base; kv_idx_base += NUM_WARP_KV * num_frags_z * 16; produce_k_blockwise_c8( - v_smem_scale, - cache_v_scale_reg, - block_table_now, - cache_v_scale, - ori_kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } wait_group<1>(); __syncthreads(); @@ -865,9 +793,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( BLOCK_SIZE, T, CacheT, - is_scale_channel_wise, - IsFP8, - IsDynamicC8>( + is_scale_channel_wise, IsFP8>( &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); __syncthreads(); @@ -981,8 +907,7 @@ template + bool IsFP8=false> void MultiQueryAppendC8Attention( const AppendAttnMetaData &meta_data, const paddle::Tensor &qkv, @@ -1040,8 +965,7 @@ void MultiQueryAppendC8Attention( constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; constexpr uint32_t smem_size = num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + - num_frags_z * 16 * sizeof(T) * 2; + num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2; auto split_kv_kernel = multi_query_append_attention_c8_kernel; + false, IsFP8>; if (is_scale_channel_wise) { split_kv_kernel = multi_query_append_attention_c8_kernel; + true, IsFP8>; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(split_kv_kernel, @@ -1114,9 +1034,7 @@ void MultiQueryAppendC8Attention( num_frags_y, OUT_NV_TYPE, ENABLE_PREFILL, - false, - IsFP8, - IsDynamicC8>; + false, IsFP8>; if (is_scale_channel_wise) { nosplit_kv_kernel = multi_query_append_attention_c8_kernel; + true, IsFP8>; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(nosplit_kv_kernel, @@ -1316,8 +1232,7 @@ void MultiQueryAppendC8Attention( constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; constexpr uint32_t smem_size = num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + - NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2; + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2; auto split_kv_kernel = multi_query_append_attention_c8_warp1_4_kernel; + false, IsFP8>; if (is_scale_channel_wise) { split_kv_kernel = multi_query_append_attention_c8_warp1_4_kernel; + true, IsFP8>; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(split_kv_kernel, @@ -1398,9 +1309,7 @@ void MultiQueryAppendC8Attention( num_frags_y, OUT_NV_TYPE, ENABLE_PREFILL, - false, - IsFP8, - IsDynamicC8>; + false, IsFP8>; if (is_scale_channel_wise) { nosplit_kv_kernel = multi_query_append_attention_c8_warp1_4_kernel; + true, IsFP8>; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(nosplit_kv_kernel, @@ -1655,7 +1562,6 @@ void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, - const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out) { const auto token_num = meta_data.token_nums; @@ -1664,7 +1570,6 @@ void CascadeAppendAttentionC8Kernel( const auto num_heads = meta_data.q_num_heads; const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; const auto head_dim = meta_data.head_dims; - bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8"; DISPATCH_CAUSAL( causal, @@ -1683,46 +1588,43 @@ void CascadeAppendAttentionC8Kernel( BLOCK_SIZE, {DISPATCH_BLOCKSHAPE_Q( block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, { - DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, { - MultiQueryAppendC8Attention( - meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale.get(), - cache_v_scale.get(), - shift_bias, - smooth_weight, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - is_decoder, - stream, - out); - })})})})})})}) + MultiQueryAppendC8Attention( + meta_data, + qkv, + cache_k, + cache_v, + attn_mask, + cache_k_scale.get(), + cache_v_scale.get(), + shift_bias, + smooth_weight, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + is_decoder, + stream, + out); + })})})})})}) } diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 24787e8b72..146d0c30ad 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -384,113 +384,6 @@ __device__ __forceinline__ void produce_v_blockwise_c8( } } -template -__device__ __forceinline__ void produce_k_dynamic_scale( - T* k_smem_scale, - T* cache_k_reg, - const int* block_table_now, - const T* cache_k_scale, - const uint32_t kv_idx, - const uint32_t kv_num_heads, - const uint32_t kv_head_idx, - const uint32_t chunk_end -) { - const uint32_t tx = threadIdx.x, ty = threadIdx.y; - if constexpr (NUM_WARP_Q == 4) { - // 4 warps shared block_size - const uint32_t tid = ty * 32 + tx; - int block_id = __ldg(&block_table_now[kv_idx / block_size]); - if (block_id < 0) block_id = 0; - const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; - if (tid < block_size) { - k_smem_scale[tid] = cache_k_scale_now[tid]; - } - __syncthreads(); - const uint32_t row_id = tx / 4; - for (uint32_t fz = 0; fz < num_frags_z; fz++) { - cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id]; - cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8]; - } - } else { - // 1 warp 32 tokens - const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; - int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); - if (block_id < 0) block_id = 0; - const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; - const int kv_idx_this_thread = kv_idx + ty * 32 + tx; - if (kv_idx_this_thread < chunk_end) { - k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx]; - } else { - k_smem_scale[ty * 32 + tx] = 0; - } - __syncwarp(); - const uint32_t row_id = tx / 4; - for (uint32_t fz = 0; fz < num_frags_z; fz++) { - cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id]; - cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8]; - } - } -} - -template -__device__ __forceinline__ void produce_v_dynamic_scale( - T* v_smem_scale, - T* cache_v_reg, - const int* block_table_now, - const T* cache_v_scale, - const uint32_t kv_idx, - const uint32_t kv_num_heads, - const uint32_t kv_head_idx, - const uint32_t chunk_end -) { - const uint32_t tx = threadIdx.x, ty = threadIdx.y; - - if constexpr (NUM_WARP_Q == 4) { - // 4 warps shared block_size - const uint32_t tid = ty * 32 + tx; - int block_id = __ldg(&block_table_now[kv_idx / block_size]); - if (block_id < 0) block_id = 0; - const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; - if (tid < block_size) { - v_smem_scale[tid] = cache_v_scale_now[tid]; - } - __syncthreads(); - const uint32_t row_id = tx % 4 * 2; - for (uint32_t fz = 0; fz < num_frags_z; fz++) { - cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id]; - cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1]; - cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8]; - cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9]; - } - } else { - // 1 warp 32 tokens - const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; - int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); - if (block_id < 0) block_id = 0; - const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; - const int kv_idx_this_thread = kv_idx + ty * 32 + tx; - if (kv_idx_this_thread < chunk_end) { - v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx]; - } else { - v_smem_scale[ty * 32 + tx] = 0; - } - __syncwarp(); - const uint32_t row_id = tx % 4 * 2; - for (uint32_t fz = 0; fz < num_frags_z; fz++) { - cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id]; - cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1]; - cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8]; - cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9]; - } - } -} - template + bool IsFP8=false> __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, @@ -968,27 +860,20 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, convert_c8(b_frag_dq_T, b_frag[fy * 2]); convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); // scale zp - if constexpr (!IsDynamicC8) { - if constexpr (is_scale_channel_wise) { - const int scale_col = (ky * 2 + fy) * 4; - b_frag_dq_T[0] *= cache_k_scale[scale_col]; - b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; - b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; - b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; - b_frag_dq_T[4] *= cache_k_scale[scale_col]; - b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; - b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; - b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; - } else { -#pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_k_scale[0]; - } - } + if constexpr (is_scale_channel_wise) { + const int scale_col = (ky * 2 + fy) * 4; + b_frag_dq_T[0] *= cache_k_scale[scale_col]; + b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_k_scale[scale_col]; + b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; } else { #pragma unroll for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4]; + b_frag_dq_T[b_i] *= cache_k_scale[0]; } } #pragma unroll @@ -1208,9 +1093,7 @@ template + bool is_scale_channel_wise = false, bool IsFP8=false> __device__ __forceinline__ void compute_sfm_v_c8( smem_t* v_smem, uint32_t* v_smem_offset_r, @@ -1252,28 +1135,16 @@ __device__ __forceinline__ void compute_sfm_v_c8( convert_c8(b_frag_dq_T, b_frag[fz * 2]); convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp - if constexpr (!IsDynamicC8) { - if constexpr (is_scale_channel_wise) { -#pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; - } - } else { + if constexpr (is_scale_channel_wise) { #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[0]; - } + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; } } else { - const int scale_col = (kz * 2 + fz) * 4; - b_frag_dq_T[0] *= cache_v_scale[scale_col]; - b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; - b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; - b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; - b_frag_dq_T[4] *= cache_v_scale[scale_col]; - b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; - b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; - b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } } #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 @@ -1300,9 +1171,7 @@ template + bool is_scale_channel_wise = false, bool IsFP8=false> __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( smem_t* v_smem, uint32_t* v_smem_offset_r, @@ -1346,28 +1215,16 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( convert_c8(b_frag_dq_T, b_frag[fz * 2]); convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp - if constexpr (!IsDynamicC8) { - if constexpr (is_scale_channel_wise) { + if constexpr (is_scale_channel_wise) { #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; - } - } else { - #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[0]; - } + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; } } else { - const int scale_col = (kz * 2 + fz) * 4; - b_frag_dq_T[0] *= cache_v_scale[scale_col]; - b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; - b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; - b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; - b_frag_dq_T[4] *= cache_v_scale[scale_col]; - b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; - b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; - b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; + #pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } } #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 diff --git a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h index 2cc0695928..8799c0a705 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h @@ -103,7 +103,6 @@ void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, - const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); @@ -265,10 +264,9 @@ void CascadeAppendAttentionKernel( causal, is_decoder, enable_prefill, - cache_quant_type_str, stream, out); - } else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") { + } else if (cache_quant_type_str == "cache_fp8") { CascadeAppendAttentionC8Kernel(meta_data, qkv, cache_k, @@ -301,7 +299,6 @@ void CascadeAppendAttentionKernel( causal, is_decoder, enable_prefill, - cache_quant_type_str, stream, out); } else if (cache_quant_type_str == "cache_int4_zp") { diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index 2a56caa174..45c9d0a024 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -674,294 +674,6 @@ __global__ void append_decode_cache_T_neox_rope_kernel( } } -template -__global__ void append_decode_cache_int8_rope_qk_norm_kernel( - const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, - // head_size] - uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, - // block_size, head_size // 2] - uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads, - // block_size, head_size // 2] - T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] - const int* __restrict__ cu_seqlens_q, - const int* __restrict__ seq_lens, // [bsz] - const int* __restrict__ seq_lens_encoder, // [bsz] - const float* __restrict__ cos_emb, - const float* __restrict__ sin_emb, - T* __restrict__ cache_k_scale, - T* __restrict__ cache_v_scale, - const float* q_norm_weight, - const float* k_norm_weight, - const int max_seq_len, - const int max_blocks_per_seq, - const int num_heads, - const int block_size, - const float max_bound, - const float min_bound, - const int kv_num_heads, - const bool rope_3d, - const float rms_norm_eps) { - static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); - static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); - constexpr int NUM_WARPS = 4; - const int tid = threadIdx.x; - const int wid = tid / 32; - const int lane_id = tid % 32; - const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; - int q_head_idx, k_head_idx, v_idx; - const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim; - constexpr int half_head_size = HeadDim / 2; - const int start_token_idx = cu_seqlens_q[bid]; - if (seq_lens_encoder[bid] > 0) return; - const int write_seq_id = seq_lens[bid]; - if (write_seq_id == 0) return; - const int* block_table_now = nullptr; - - block_table_now = block_tables + bid * max_blocks_per_seq; - const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); - const int block_offset = write_seq_id % block_size; - - int cache_offset; - if (head_idx < num_heads) { - cache_offset = 0; - } else if (head_idx < num_heads + 2 * kv_num_heads) { - cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset; - } - T *cache_k_scale_now = cache_k_scale + cache_offset; - T *cache_v_scale_now = cache_v_scale + cache_offset; - - float thread_m2 = 0.0f; - float warp_m2 = 0.0f; - - if (head_idx < num_heads) { - // q - using LoadT = AlignedVector; - using LoadBiasT = AlignedVector; - using LoadOutScaleT = AlignedVector; - constexpr int HalfVecSize = VecSize / 2; - using LoadEmbT = AlignedVector; - - LoadT src_vec; - LoadBiasT out_vec; - LoadEmbT cos_emb_vec; - LoadEmbT sin_emb_vec; - const T* qkv_now = quant_qkv + start_token_idx * hidden_size; - T* qkv_out_now = qkv_out + start_token_idx * hidden_size; -#pragma unroll - for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; - head_bias += 32 * VecSize) { - const int bias_idx = head_idx * HeadDim + head_bias; - Load(&qkv_now[bias_idx], &src_vec); - // q rope - const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; - Load(&cos_emb[new_emb_idx], &cos_emb_vec); - Load(&sin_emb[new_emb_idx], &sin_emb_vec); -#pragma unroll - for (int i = 0; i < HalfVecSize; i++) { - // dequant + add_bias + rope - float input_left = static_cast(src_vec[2 * i]); - float input_right = static_cast(src_vec[2 * i + 1]); - - const float cos_tmp = cos_emb_vec[i]; - const float sin_tmp = sin_emb_vec[i]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; - thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; - out_vec[2 * i] = - static_cast(tmp1); - out_vec[2 * i + 1] = - static_cast(tmp2); - } - // qk norm - if (q_norm_weight) { - WelfordWarpAllReduce(thread_m2, &warp_m2); - float row_variance = - max(warp_m2 / HeadDim, 0.0f); - float row_inv_var = Rsqrt(row_variance + rms_norm_eps); - LoadOutScaleT q_norm_vec; - Load(&q_norm_weight[lane_id * VecSize], &q_norm_vec); - #pragma unroll - for (int i = 0; i < VecSize; i++) { - out_vec[i] = static_cast(static_cast(out_vec[i]) * row_inv_var * q_norm_vec[i]); - } - } - Store(out_vec, &qkv_out_now[bias_idx]); - } - } else if (head_idx < num_heads + 2 * kv_num_heads) { - // k - constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 - using LoadPadKVT = AlignedVector; - const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads; - if (block_offset == 0) { - // pad zero for this kv_head_idx for this block - LoadPadKVT pad_cache_vec; - *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); - if (head_idx < num_heads + kv_num_heads) { - constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE; - constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; - const uint32_t tgt_idx = - (block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim + - lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; - for (int block_i = lane_id / num_vecs_per_head_dim; - block_i < block_size; - block_i += num_token_each_time) { - Store(pad_cache_vec, - &key_cache[tgt_idx + block_i * HeadDim]); - } - } else { - const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE; - const int num_token_each_time = 32 / num_vecs_per_head_dim; - const uint32_t tgt_idx = - (block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size + - lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; - for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; - block_i += num_token_each_time) { - Store( - pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]); - } - } - __syncwarp(); - } - - constexpr int K_VEC_SIZE = 4; - constexpr int HALF_K_VEC_SIZE = 2; - using LoadKVResT = AlignedVector; - using LoadKVT = AlignedVector; - using LoadT = AlignedVector; - using LoadBiasT = AlignedVector; - using LoadOutScaleT = AlignedVector; - using LoadEmbT = AlignedVector; - LoadKVResT cache_vec; - LoadT src_vec1, src_vec2; - LoadBiasT out_vec1, out_vec2; - LoadEmbT cos_emb_vec1, cos_emb_vec2; - LoadEmbT sin_emb_vec1, sin_emb_vec2; - - const T* qkv_now = quant_qkv + start_token_idx * hidden_size; - const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; - const int bias_idx = head_idx * HeadDim + head_bias; - Load(&qkv_now[bias_idx], &src_vec1); - Load(&qkv_now[bias_idx + 8], &src_vec2); - T scale = T(1.0f); - const int k_head_idx = head_idx - num_heads; - const int v_head_idx = head_idx - num_heads - kv_num_heads; - if (head_idx < num_heads + kv_num_heads) { - const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; - Load(&cos_emb[new_emb_idx], &cos_emb_vec1); - Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); - Load(&sin_emb[new_emb_idx], &sin_emb_vec1); - Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); - } - - float input_left = static_cast(src_vec1[0]); - float input_right = static_cast(src_vec1[1]); - if (head_idx < num_heads + kv_num_heads) { - float cos_tmp = cos_emb_vec1[0]; - float sin_tmp = sin_emb_vec1[0]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; - thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; - out_vec1[0] = - static_cast(tmp1); - out_vec1[1] = - static_cast(tmp2); - } else { - out_vec1[0] = src_vec1[0]; - out_vec1[1] = src_vec1[1]; - } - - // rope - input_left = static_cast(src_vec2[0]); - input_right = static_cast(src_vec2[1]); - if (head_idx < num_heads + kv_num_heads) { - float cos_tmp = cos_emb_vec2[0]; - float sin_tmp = sin_emb_vec2[0]; - float tmp1 = input_left * cos_tmp - input_right * sin_tmp; - float tmp2 = input_right * cos_tmp + input_left * sin_tmp; - thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; - out_vec2[0] = static_cast(tmp1); - out_vec2[1] = static_cast(tmp2); - } else { - out_vec2[0] = src_vec2[0]; - out_vec2[1] = src_vec2[1]; - } - if (k_norm_weight) { - if (head_idx < num_heads + kv_num_heads) { - LoadOutScaleT k_norm_vec1, k_norm_vec2; - Load(&k_norm_weight[head_bias], &k_norm_vec1); - Load(&k_norm_weight[head_bias + 8], &k_norm_vec2); - // qk norm - WelfordWarpAllReduce(thread_m2, &warp_m2); - float row_variance = - max(warp_m2 / HeadDim, 0.0f); - float row_inv_var = Rsqrt(row_variance + rms_norm_eps); - - for (int i = 0; i < HALF_K_VEC_SIZE; i++) { - out_vec1[i] = static_cast(static_cast(out_vec1[i]) * row_inv_var * k_norm_vec1[i]); - out_vec2[i] = static_cast(static_cast(out_vec2[i]) * row_inv_var * k_norm_vec2[i]); - } - } - } - // reduce max, 1 head per warp - T local_max = -INFINITY; -#pragma unroll - for (int i = 0; i < HALF_K_VEC_SIZE; i++) { - local_max = __hmax(local_max, __habs(out_vec1[i])); - local_max = __hmax(local_max, __habs(out_vec2[i])); - } -#pragma unroll - for (int m_offset = 16; m_offset > 1; m_offset /= 2) { - local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset)); - } - - scale = __hdiv(448, local_max); - - if (lane_id == 0) { - if (head_idx < num_heads + kv_num_heads) { - cache_k_scale_now[0] = __hdiv(1, scale); - } else { - cache_v_scale_now[0] = __hdiv(1, scale); - } - } - -#pragma unroll - for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { - cache_vec[i] = QuantToC8(scale, out_vec1[i], max_bound, min_bound); - cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8(scale, out_vec2[i], max_bound, min_bound); - } - if (head_idx < num_heads + kv_num_heads) { - const int start_block_16 = - block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; - const uint32_t tgt_cache_idx = - block_idx * kv_num_heads * block_size * HeadDim + - kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim + - lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; - Store(cache_vec, &key_cache[tgt_cache_idx]); - } else { - const uint32_t base_tgt_cache_idx = - block_idx * kv_num_heads * HeadDim * block_size + - kv_head_idx * HeadDim * block_size + - (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + - block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; - const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + - block_offset % 8 / 2 * 4 // per 4 - + block_offset % 16 / 8 * 2 // per 2 - + block_offset % 2; // per 1 - const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; - const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; - const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; - value_cache[tgt_cache_idx1] = cache_vec[0]; - value_cache[tgt_cache_idx2] = cache_vec[1]; - value_cache[tgt_cache_idx3] = cache_vec[2]; - value_cache[tgt_cache_idx4] = cache_vec[3]; - } - } -} - template __global__ void append_decode_cache_int8_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index c067efc759..d6643ca208 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -553,40 +553,9 @@ void DecoderWriteCacheWithRoPEKernel( q_norm_weight ? q_norm_weight.get().data() : nullptr, k_norm_weight ? k_norm_weight.get().data() : nullptr, rms_norm_eps); - } else if (cache_quant_type_str == "block_wise_fp8") { - constexpr int num_warps = 4; - const int all_warps = - ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; - dim3 grids(bsz, all_warps / num_warps); - append_decode_cache_int8_rope_qk_norm_kernel - <<>>( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - const_cast(reinterpret_cast(cache_k_scale.get().data())), - const_cast(reinterpret_cast((cache_v_scale.get().data()))), - q_norm_weight.get().data(), - k_norm_weight.get().data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d, - rms_norm_eps); } else { PD_THROW( - "append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8"); + "append_decode_cache_rope_qk_norm not support cachekv quant yet"); } } else { if (cache_quant_type_str == "none") { @@ -717,37 +686,6 @@ void DecoderWriteCacheWithRoPEKernel( stream, use_neox_rotary_style, rope_3d); - } else if (cache_quant_type_str == "block_wise_fp8") { - constexpr int num_warps = 4; - const int all_warps = - ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; - dim3 grids(bsz, all_warps / num_warps); - append_decode_cache_int8_rope_qk_norm_kernel - <<>>( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - const_cast(reinterpret_cast(cache_k_scale.get().data())), - const_cast(reinterpret_cast((cache_v_scale.get().data()))), - nullptr, - nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d, - rms_norm_eps); } else if (cache_quant_type_str == "cache_int4_zp") { append_decode_cache_int4_rope( reinterpret_cast(qkv_ptr), diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh index c4e46a6e50..b38c177124 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -1232,411 +1232,6 @@ __global__ void append_write_cache_kv_c8_qkv( } } -template -__global__ void append_write_cache_kv_c8_qkv_dynamic( - uint8_t *__restrict__ cache_k, - uint8_t *__restrict__ cache_v, - const T *__restrict__ qkv_input, - T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size] - T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size] - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids, - const int *__restrict__ seq_lens_this_time, - const int *__restrict__ seq_lens_decoder, - const int *__restrict__ batch_id_per_token, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_tables, - const int max_seq_len, - const int max_blocks_per_seq, - const int num_heads, - const int kv_num_heads) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - constexpr uint32_t pad_len = BLOCK_SIZE; - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const T cache_k_scale = cache_k_scales[kv_head_idx]; - const T cache_v_scale = cache_v_scales[kv_head_idx]; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids[btid]; - const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; - if (seq_len_this_time <= 0) { - return; - } - const int *block_table_now = nullptr; - - block_table_now = block_tables + batch_id * max_blocks_per_seq; - - const uint32_t num_rows_per_block = - NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE - const uint32_t start_len = seq_lens_decoder[batch_id]; - const uint32_t bf_pad_len = start_len % pad_len; - const uint32_t start_len_pad = start_len - bf_pad_len; - const uint32_t end_len = start_len + seq_len_this_time; - - const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block; - int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]); - uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8; - - const uint32_t start_token_idx = cu_seqlens_q[batch_id]; - const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM; - const uint32_t kv_h_stride = HEAD_DIM; - __shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM]; - __shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM]; - __shared__ T v_scale_smem[BLOCK_SIZE]; - if (tile_start >= start_len) { - constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 - using LoadPadKVT = AlignedVector; - // pad zero for this kv_head_idx for this block - LoadPadKVT pad_cache_vec; - *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); - // reset k - constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE; - constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; - uint32_t tgt_idx = - (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM + - tid % num_vecs_per_head_k * KV_VEC_SIZE; - for (int block_i = tid / num_vecs_per_head_k; - block_i < BLOCK_SIZE; - block_i += num_token_each_time_k) { - Store(pad_cache_vec, - &cache_k[tgt_idx + block_i * HEAD_DIM]); - } - - // reset v - const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE; - const int num_token_each_time_v = 32 / num_vecs_per_head_v; - tgt_idx = - (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE + - tid % num_vecs_per_head_v * KV_VEC_SIZE; - for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM; - block_i += num_token_each_time_v) { - Store( - pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]); - } - } - smem_t k_smem(k_smem_ori); - smem_t v_smem(v_smem_ori); - - uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( - wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp - - /* - 0 | 1 - 2 | 3 - */ - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS; - /* - 0 | 2 - 1 | 3 - */ - uint32_t v_smem_offset_r = smem_t::get_permuted_offset( - tid % 16, wid * num_frags_v * 2 + tid / 16); - - // load kv gmem to smem - const uint32_t real_start_token_idx = start_token_idx - bf_pad_len + - tile_id * num_rows_per_block + - wid * num_frags_z * 16 + tid / 8; - uint32_t k_read_idx = real_start_token_idx * kv_batch_stride + - (num_heads + kv_head_idx) * kv_h_stride + - tid % 8 * num_elems_per_128b(); - uint32_t v_read_idx = real_start_token_idx * kv_batch_stride + - (num_heads + kv_num_heads + kv_head_idx) * kv_h_stride + - tid % 8 * num_elems_per_128b(); -#pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { -#pragma unroll - for (uint32_t j = 0; j < 4; ++j) { -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y / 4; - ++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b()) - if (chunk_start >= start_len && chunk_start < end_len) { - k_smem.load_128b_async( - kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len); - v_smem.load_128b_async( - kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len); - } - kv_smem_offset_w = - k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy); - k_read_idx += 8 * num_elems_per_128b(); - v_read_idx += 8 * num_elems_per_128b(); - } - kv_smem_offset_w = - k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) - - 2 * num_frags_y; - chunk_start += 4; - k_read_idx += - 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); - v_read_idx += - 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); - } - } - commit_group(); - wait_group<0>(); - __syncthreads(); - - // reduce scale - // 16 rows per warp - uint32_t kv_reduce_frag[4]; - T *kv_reduce_frag_T = reinterpret_cast(kv_reduce_frag); - - T k_local_max_value[num_frags_z * 2]; - T v_local_max_value[num_frags_z * 2]; -#pragma unroll - for (int i = 0; i < num_frags_z * 2; i++) { - k_local_max_value[i] = -INFINITY; - } -#pragma unroll - for (int i = 0; i < num_frags_z * 2; i++) { - v_local_max_value[i] = -INFINITY; - } - const int num_kv_heads = gridDim.z; - const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE; - T *cache_k_scale_now = cache_k_scales + scale_offset; - T *cache_v_scale_now = cache_v_scales + scale_offset; - // k scale -#pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - // reduce per thread, 4 threads each row - k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag); -#pragma unroll - for (int i = 0; i < 4; i++) { - k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]); - } -#pragma unroll - for (int i = 0; i < 4; i++) { - k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]); - } - k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); - } - // reduce per row - for (int i = 0; i < 2; i++) { - T local_max_value = __habs(k_local_max_value[fz * 2 + i]); - local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2)); - local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1)); - // used for quant - k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value); - } - // store - if (tid % 4 == 0) { - const int offset_now = wid * num_frags_z * 16 + tid / 4; - // used for dequant - if (tile_start + offset_now >= start_len) { - if (tile_start + offset_now < end_len) { - cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]); - } else { - cache_k_scale_now[offset_now] = 0; - } - } - if (tile_start + offset_now + 8 >= start_len) { - if (tile_start + offset_now + 8 < end_len) { - cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]); - } else { - cache_k_scale_now[offset_now + 8] = 0; - } - } - } - __syncthreads(); - k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 - } - // v scale - #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - // reduce per thread, 4 threads each row - v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag); -#pragma unroll - for (int i = 0; i < 4; i++) { - v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]); - } -#pragma unroll - for (int i = 0; i < 4; i++) { - v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]); - } - k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); - } - // reduce per row - for (int i = 0; i < 2; i++) { - T local_max_value = __habs(v_local_max_value[fz * 2 + i]); - local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2)); - local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1)); - v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value); - } - // store - if (tid % 4 == 0) { - const int offset_now = wid * num_frags_z * 16 + tid / 4; - // used for dequant - if (tile_start + offset_now >= start_len) { - if (tile_start + offset_now < end_len) { - cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]); - v_scale_smem[offset_now] = v_local_max_value[fz * 2]; - } else { - cache_v_scale_now[offset_now] = 0; - v_scale_smem[offset_now] = 0; - } - } - if (tile_start + offset_now + 8 >= start_len) { - if (tile_start + offset_now + 8 < end_len) { - cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]); - v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1]; - } else { - cache_v_scale_now[offset_now + 8] = 0; - v_scale_smem[offset_now + 8] = 0; - } - } - } - __syncthreads(); - k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 - } - __syncthreads(); - - // mask, quant, store - using LoadKVT = AlignedVector; - LoadKVT cache_vec1; - LoadKVT cache_vec2; - - uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4; - uint32_t kv_frag[4]; - const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; - const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM; - const uint32_t write_b_stride = HEAD_DIM; - const uint32_t write_d_stride = BLOCK_SIZE; - uint32_t k_write_idx = block_id * write_n_stride + - kv_head_idx * write_h_stride + - (wid * num_frags_z * 16 + tid / 4) * write_b_stride + - tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit -#pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { - uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride; -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - uint32_t k_write_idx_now = k_write_idx_now_z + - fy % 2 * 8 * write_b_stride + - fy / 2 * 32; // + fy % 2 * 16; - // load - k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); - // quant - T *k_frag_T = reinterpret_cast(kv_frag); - if (bf_pad_len != 0) { - Load(cache_k + k_write_idx_now, &cache_vec1); - Load(cache_k + k_write_idx_now + 16, &cache_vec2); - } -#pragma unroll - for (uint32_t v_id = 0; v_id < 8; ++v_id) { - uint8_t uint_quant_value; - if (chunk_start_k + (v_id / 4) * 8 >= start_len && - chunk_start_k + (v_id / 4) * 8 < end_len) { - uint_quant_value = QuantToC8(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f); - } else { - uint_quant_value = 0; - } - if (bf_pad_len != 0) { - if (v_id < 4) { - cache_vec1[v_id] |= uint_quant_value; - } else { - cache_vec2[v_id % 4] |= uint_quant_value; - } - } else { - if (v_id < 4) { - cache_vec1[v_id] = uint_quant_value; - } else { - cache_vec2[v_id - 4] = uint_quant_value; - } - } - } - // store - Store(cache_vec1, cache_k + k_write_idx_now); - Store(cache_vec2, cache_k + k_write_idx_now + 16); - k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); - } - k_smem_offset_r = - k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) - - 2 * num_frags_y; - chunk_start_k += 16; - } - - uint32_t chunk_start_v = tile_start + tid % 4 * 2; - uint32_t v_write_idx = block_id * write_n_stride + - kv_head_idx * write_h_stride + - (wid * num_frags_v * 16 + tid / 4) * write_d_stride + - tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit - const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS; - T v_scales[num_frags_z_v * 4]; - for (int v_i = 0; v_i < num_frags_z_v; v_i++) { - const int offset = v_i * 16; - const int t_offset = tid % 4 * 2; - v_scales[v_i * 4] = v_scale_smem[offset + t_offset]; - v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1]; - v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8]; - v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9]; - } - -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_v; ++fy) { - uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride; -#pragma unroll - for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) { - uint32_t v_write_idx_now = v_write_idx_now_v + - fz % 2 * 8 * write_d_stride + - fz / 2 * 32; // + fz % 2 * 16; - // load - v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag); - // quant - T *v_frag_T = reinterpret_cast(kv_frag); - if (bf_pad_len != 0) { - Load(cache_v + v_write_idx_now, &cache_vec1); - Load(cache_v + v_write_idx_now + 16, &cache_vec2); - } -#pragma unroll - for (uint32_t v_id = 0; v_id < 8; ++v_id) { - uint8_t uint_quant_value; - if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len && - chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) { - uint_quant_value = QuantToC8(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f); - // store now - } else { - uint_quant_value = 0; - } - if (bf_pad_len != 0) { - if (v_id < 4) { - cache_vec1[v_id] |= uint_quant_value; - } else { - cache_vec2[v_id % 4] |= uint_quant_value; - } - } else { - if (v_id < 4) { - cache_vec1[v_id] = uint_quant_value; - } else { - cache_vec2[v_id % 4] = uint_quant_value; - } - } - } - // store - Store(cache_vec1, cache_v + v_write_idx_now); - Store(cache_vec2, cache_v + v_write_idx_now + 16); - chunk_start_v += 16; - v_smem_offset_r = - k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r); - } - v_smem_offset_r = k_smem.advance_offset_by_column<2>( - v_smem_offset_r, wid * num_frags_v + fy) - - 16 * num_frags_z_v * num_vecs_per_head; - chunk_start_v -= 16 * num_frags_z_v; - } -} - // Write Cache KV in Append template ::type; auto max_blocks_per_seq = meta_data.max_blocks_per_seq; auto num_tokens = meta_data.token_nums; auto num_heads = meta_data.q_num_heads; @@ -2433,77 +2027,49 @@ void CascadeAppendWriteCacheKVC8QKV( dim3 blocks(32, num_warps); const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2; - if (cache_quant_type != "block_wise_fp8") { - auto kernel_fn = append_write_cache_kv_c8_qkv; - if (cache_quant_type == "cache_fp8") { - kernel_fn = append_write_cache_kv_c8_qkv; - } - if (is_scale_channel_wise) { - kernel_fn = append_write_cache_kv_c8_qkv; - } - cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel_fn<<>>(cache_k_out->data(), - cache_v_out->data(), - qkv.data(), - cache_k_scale.data(), - cache_v_scale.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads); - } else { - auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic; - cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel_fn<<>>(cache_k_out->data(), - cache_v_out->data(), - reinterpret_cast(qkv.data()), - const_cast(reinterpret_cast(cache_k_scale.data())), - const_cast(reinterpret_cast(cache_v_scale.data())), - batch_ids.data(), - tile_ids_per_batch.data(), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads); + auto kernel_fn = append_write_cache_kv_c8_qkv; + if (is_fp8) { + kernel_fn = append_write_cache_kv_c8_qkv; + } + if (is_scale_channel_wise) { + kernel_fn = append_write_cache_kv_c8_qkv; } + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel_fn<<>>(cache_k_out->data(), + cache_v_out->data(), + qkv.data(), + cache_k_scale.data(), + cache_v_scale.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads); } template diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h index b0d66a2913..5af84e73f3 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h @@ -167,7 +167,7 @@ void EncoderWriteCacheWithRopeKernel( stream, key_cache_out, value_cache_out); - } else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") { + } else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") { DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { CascadeAppendWriteCacheKVC8QKV( @@ -187,7 +187,7 @@ void EncoderWriteCacheWithRopeKernel( num_blocks, max_seq_len, is_scale_channel_wise, - cache_quant_type_str, + cache_quant_type_str == "cache_fp8", stream, key_cache_out, value_cache_out); diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index 0388b9fb6c..40e682963b 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -1000,7 +1000,7 @@ std::vector GQARopeWriteCacheKernel( stream, const_cast(&key_cache), const_cast(&value_cache)); - } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") { + } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { CascadeAppendWriteCacheKVC8QKV( meta_data, *const_cast(&key_cache), @@ -1018,7 +1018,7 @@ std::vector GQARopeWriteCacheKernel( kv_num_blocks_data, max_seq_len, false, // is_scale_channel_wise - cache_quant_type, + cache_quant_type == "cache_fp8", // is_fp8 stream, const_cast(&key_cache), const_cast(&value_cache)); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu index 757cccaf97..e860a04626 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu @@ -56,7 +56,6 @@ CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, - const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); @@ -104,6 +103,5 @@ CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, - const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu index 54b0b0be4f..3b61ecd16b 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu @@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, - const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); @@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, - const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu index 153b81ee05..4d7b11d99c 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu @@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, - const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); @@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, - const std::string& cache_quant_type_str, cudaStream_t& stream, paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index 12d86dade8..13874a3f94 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -441,15 +441,6 @@ __forceinline__ __host__ __device__ void vec_cast( PD_THROW("not support the group_size", group_size); \ } -#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \ - if (is_dynamic_cfp8) { \ - constexpr bool IsDynamicC8 = true; \ - __VA_ARGS__ \ - } else { \ - constexpr bool IsDynamicC8 = false; \ - __VA_ARGS__ \ - } - #define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ if (group_size == 8) { \ constexpr size_t GROUP_SIZE = 8; \ diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 64023e7e25..d201f06e60 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -222,17 +222,6 @@ def forward_mixed( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, ) - cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none") - if cache_quant_type_str == "block_wise_fp8": - cache_k = forward_meta.caches[4 * layer.layer_id] - cache_v = forward_meta.caches[4 * layer.layer_id + 1] - cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2] - cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3] - else: - cache_k = forward_meta.caches[2 * layer.layer_id] - cache_v = forward_meta.caches[2 * layer.layer_id + 1] - cache_k_scales = getattr(layer, "cache_k_scale", None) - cache_v_scales = getattr(layer, "cache_v_scale", None) if self.use_output: quant_max_bound = getattr(layer, "quant_max_bound", 0.0) @@ -271,8 +260,8 @@ def forward_mixed( append_attention_with_output( qkv, - cache_k, - cache_v, + forward_meta.caches[2 * layer.layer_id], + forward_meta.caches[2 * layer.layer_id + 1], forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -295,8 +284,8 @@ def forward_mixed( metadata.attn_mask, layer.qkv_bias, layer.qkv_scale, - cache_k_scales, - cache_v_scales, + getattr(layer, "cache_k_scale", None), + getattr(layer, "cache_v_scale", None), getattr(layer, "cache_k_out_scale", None), getattr(layer, "cache_v_out_scale", None), getattr(layer, "cache_k_zp", None), @@ -327,8 +316,8 @@ def forward_mixed( else: res = append_attention( qkv, - cache_k, - cache_v, + forward_meta.caches[2 * layer.layer_id], + forward_meta.caches[2 * layer.layer_id + 1], forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -350,8 +339,8 @@ def forward_mixed( metadata.attn_mask, layer.qkv_bias, layer.qkv_scale, - cache_k_scales, - cache_v_scales, + getattr(layer, "cache_k_scale", None), + getattr(layer, "cache_v_scale", None), getattr(layer, "cache_k_out_scale", None), getattr(layer, "cache_v_out_scale", None), getattr(layer, "cache_k_zp", None), diff --git a/fastdeploy/model_executor/layers/quantization/kv_cache.py b/fastdeploy/model_executor/layers/quantization/kv_cache.py index d7727da5ca..d560e6122e 100644 --- a/fastdeploy/model_executor/layers/quantization/kv_cache.py +++ b/fastdeploy/model_executor/layers/quantization/kv_cache.py @@ -33,7 +33,6 @@ class KvCacheQuantzationTypes(str, Enum): INT8 = "int8" FP8 = "float8_e4m3fn" - BLOCK_WISE_FP8 = "block_wise_fp8" INT8_ZP = "int8_zp" INT4_ZP = "int4_zp" FP8_ZP = "float8_e4m3fn_zp" @@ -63,11 +62,7 @@ def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_poi if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP: self.max_bound = 127.0 - elif ( - self.quant_type == KvCacheQuantzationTypes.FP8 - or self.quant_type == KvCacheQuantzationTypes.FP8_ZP - or self.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8 - ): + elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP: self.max_bound = 448.0 elif self.quant_type == KvCacheQuantzationTypes.INT4_ZP: self.max_bound = 7.0 @@ -183,17 +178,12 @@ def create_weights(self, layer: nn.Layer, state_dict): layer.cache_quant_type_str = "cache_int4_zp" layer.quant_max_bound = 7.0 layer.quant_min_bound = -7.0 - elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8: - layer.cache_quant_type_str = "block_wise_fp8" - layer.quant_max_bound = 448.0 - layer.quant_min_bound = -448.0 else: raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented") - if "block_wise" not in layer.cache_quant_type_str: - self.load_scale(layer, state_dict) - if self.cache_quant_config.has_zero_point: - self.load_zp(layer, state_dict) + self.load_scale(layer, state_dict) + if self.cache_quant_config.has_zero_point: + self.load_zp(layer, state_dict) def apply(self, layer): """ diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 53591ca597..8453d51902 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1131,8 +1131,6 @@ def initialize_kv_cache(self, profile: bool = False) -> None: kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type ) - if kv_cache_quant_type == "block_wise_fp8": - kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]] local_rank = self.local_rank % self.parallel_config.tensor_parallel_size if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): @@ -1160,17 +1158,6 @@ def initialize_kv_cache(self, profile: bool = False) -> None: fill_value=0, dtype=cache_type, ) - if kv_cache_quant_type == "block_wise_fp8": - cache_kvs[f"key_cache_scales_{i}"] = paddle.full( - shape=kv_cache_scale_shape, - fill_value=0, - dtype=paddle.get_default_dtype(), - ) - cache_kvs[f"value_cache_scales_{i}"] = paddle.full( - shape=kv_cache_scale_shape, - fill_value=0, - dtype=paddle.get_default_dtype(), - ) self.share_inputs["caches"] = list(cache_kvs.values()) for value in cache_kvs.values(): del value diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py index fcae195095..fe04c125e1 100644 --- a/tests/layers/test_append_attention.py +++ b/tests/layers/test_append_attention.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import time import unittest @@ -21,7 +20,6 @@ from paddle.incubate.nn.functional import fused_rms_norm paddle.seed(10) -np.random.seed(10) class RopeEmbedding: @@ -336,7 +334,7 @@ def setUp(self): self.name = "TestAppendGroupQueryAttnWithRope" self.place = paddle.CUDAPlace(0) self.batch_size = 1 - self.q_num_head = 16 + self.q_num_head = 12 self.kv_num_head = 2 self.seq_len = 64 self.max_dec_len = 64 @@ -349,10 +347,9 @@ def setUp(self): self.max_seq_len = self.seq_len + self.max_dec_len self.softmax_scale = self.dim_head**-0.5 self.rope_theta = 10000 - self.dtype = "bfloat16" + self.dtype = "float16" self.use_qk_norm = True self.use_mask_offset = False - self.use_dynamic_quant = False self.init_tensor() def init_tensor(self): @@ -402,23 +399,8 @@ def init_tensor(self): ) self.scale = 1.0 / np.sqrt(self.dim_head) - if self.use_dynamic_quant: - self.cache_scale_shape = ( - self.max_block_num, - self.kv_num_head, - self.blocksize, - ) - self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8") - self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8") - self.cache_k_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) - self.cache_v_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) - self.key_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) - self.value_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) - else: - self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) - self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) - self.key_cache_scale = None - self.value_cache_scale = None + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") for i in range(self.batch_size): need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize @@ -441,7 +423,6 @@ def init_tensor(self): def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None): paddle.disable_static() - print("use_dynamic_quant: ", self.use_dynamic_quant) self.token_num = self.seq_len * self.batch_size q, k, v, qkv = get_qkv_and_qkv_concat_tensor( self.batch_size, @@ -498,67 +479,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask self.blocksize, speculate_max_draft_token_num + 1, ) - if self.use_dynamic_quant: - cache_quant_type = "block_wise_fp8" - else: - cache_quant_type = "none" - - if self.use_dynamic_quant: - qkv_copy = copy.deepcopy(qkv) - append_attention( - qkv_copy, - self.cache_k_T, - self.cache_v_T, - self.seq_lens_encoder, - self.seq_lens_decoder, - self.seq_lens_this_time, - self.padding_offset, - self.cum_offset, - self.block_tables, - self.encoder_batch_ids, - self.encoder_tile_ids_per_batch, - self.encoder_num_blocks_x_cpu, - self.kv_batch_ids, - self.kv_tile_ids_per_batch, - self.kv_num_blocks_x_cpu, - self.decoder_batch_ids, - self.decoder_tile_ids_per_batch, - self.decoder_num_blocks_cpu, - self.max_len_tensor_cpu, - self.max_len_kv_cpu, - self.rope_emb, # rope_emb - None, # attn_mask - None, # qkv_bias - None, # qkv_out_scales - None, # cache_k_quant_scales - None, # cache_v_quant_scales - None, # cache_k_dequant_scales - None, # cache_v_dequant_scales - None, # cache_k_zp - None, # cache_v_zp - None, # linear_shift - None, # linear_smooth - self.mask_offset, # mask_offset - None, # kv_signal_data - q_norm_weight, # q_norm_weight - k_norm_weight, # k_norm_weight - 1e-6, - "fp16", - "none", - self.use_neox_rotary_style, - False, - self.max_seq_len, - 0.0, # quant_min_bound - 0.0, # quant_max_bound - -1, # out_linear_in_scale - 64, # encoder_block_shape_q - 16, # decoder_block_shape_q - 32768, # max_partition_size - 32768, # encoder_max_partition_size - speculate_max_draft_token_num + 1, # speculate_max_draft_token_num - True, # causal - False, # speculate_decoder - ) # Warm up WARM_UP = 1 @@ -577,23 +497,23 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask self.padding_offset, self.cum_offset, self.block_tables, - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks, + self.encoder_batch_ids, + self.encoder_tile_ids_per_batch, + self.encoder_num_blocks_x_cpu, + self.kv_batch_ids, + self.kv_tile_ids_per_batch, + self.kv_num_blocks_x_cpu, self.decoder_batch_ids, self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, self.max_len_tensor_cpu, - max_len_kv, + self.max_len_kv_cpu, self.rope_emb, # rope_emb None, # attn_mask None, # qkv_bias None, # qkv_out_scales - self.key_cache_scale, # cache_k_quant_scales - self.value_cache_scale, # cache_v_quant_scales + None, # cache_k_quant_scales + None, # cache_v_quant_scales None, # cache_k_dequant_scales None, # cache_v_dequant_scales None, # cache_k_zp @@ -606,7 +526,7 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask k_norm_weight, # k_norm_weight 1e-6, "fp16", - cache_quant_type, + "none", # cache_quant_type self.use_neox_rotary_style, False, self.max_seq_len, @@ -624,6 +544,13 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask paddle.device.synchronize() end_time = time.time() print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms") + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) np.testing.assert_allclose( out.numpy(), out_.numpy(), @@ -652,22 +579,13 @@ def test_all(self): if self.use_mask_offset: print("encoder mask_offset: ", self.mask_offset) self.cmp_append_attention(attn_mask=self.attention_mask) - if self.use_dynamic_quant: - naive_cache_k, naive_cache_v = block_cache_to_naive_cache( - self.cache_k_T, - self.cache_v_T, - self.batch_size, - self.block_tables, - self.seq_len, - ) - else: - naive_cache_k, naive_cache_v = block_cache_to_naive_cache( - self.cache_k, - self.cache_v, - self.batch_size, - self.block_tables, - self.seq_len, - ) + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) # decoder self.seq_lens_decoder[:] = self.seq_lens_encoder self.seq_lens_encoder[:] = 0 @@ -702,10 +620,10 @@ def test_all(self): class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope): def setUp(self): paddle.disable_static() - self.name = "TestAppendGroupQueryAttnWithNeoXRope" + self.name = "TestAppendGroupQueryAttnWithRope" self.place = paddle.CUDAPlace(0) self.batch_size = 1 - self.q_num_head = 16 + self.q_num_head = 12 self.kv_num_head = 2 self.seq_len = 64 self.max_dec_len = 64 @@ -721,33 +639,6 @@ def setUp(self): self.dtype = "float16" self.use_qk_norm = False self.use_mask_offset = True - self.use_dynamic_quant = False - self.init_tensor() - - -class TestAppendGroupQueryAttnWithRopeDyCfp8(TestAppendGroupQueryAttnWithRope): - def setUp(self): - paddle.disable_static() - self.name = "TestAppendGroupQueryAttnWithRopeDyCfp8" - self.place = paddle.CUDAPlace(0) - self.batch_size = 1 - self.q_num_head = 16 - self.kv_num_head = 2 - self.seq_len = 64 - self.max_dec_len = 64 - self.dim_head = 128 - self.q_hid_dim = self.q_num_head * self.dim_head - self.kv_hid_dim = self.kv_num_head * self.dim_head - self.blocksize = 64 - self.use_neox_rotary_style = False - # max_seq_len = self.seq_len + self.max_dec_len - self.max_seq_len = self.seq_len + self.max_dec_len - self.softmax_scale = self.dim_head**-0.5 - self.rope_theta = 10000 - self.dtype = "bfloat16" - self.use_qk_norm = True - self.use_mask_offset = False - self.use_dynamic_quant = True self.init_tensor()