From e030c80ca19f879ecf544f4254042c7c89f7fc8e Mon Sep 17 00:00:00 2001 From: "river.li" Date: Fri, 29 Aug 2025 18:11:10 +0800 Subject: [PATCH 01/96] Init PA CM Impl(1st/2nd token and kvcache update) --- .../intel_gpu/primitives/paged_attention.hpp | 2 + .../graph/impls/cm/include/cm_sdpa_common.hpp | 292 ++++++++++- .../graph/impls/cm/pa_kv_cache_update_ref.cm | 103 ++++ .../src/graph/impls/cm/pa_multi_token.cm | 147 ++++++ .../src/graph/impls/cm/pa_single_token.cm | 314 +++++++++++ .../impls/cm/pa_single_token_finalization.cm | 54 ++ .../src/graph/impls/cm/paged_attention.cpp | 159 ++++++ .../src/graph/impls/cm/paged_attention.hpp | 64 +++ .../graph/impls/cm/paged_attention_gen.cpp | 492 ++++++++++++++++++ .../graph/impls/cm/paged_attention_gen.hpp | 96 ++++ .../impls/ocl_v2/sdpa/paged_attention_opt.hpp | 4 + .../graph/registry/paged_attention_impls.cpp | 2 + .../test_cases/paged_attention_gpu_test.cpp | 25 + 13 files changed, 1752 insertions(+), 2 deletions(-) create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index d8c417d611945b..f5fe561dc84046 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -10,6 +10,8 @@ namespace cldnn { +#define ENABLE_PA_CM_PATH 1 + struct paged_attention : public primitive_base { CLDNN_DECLARE_PRIMITIVE(paged_attention) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp index 74fe045cfff4a6..b4f679d198f4a2 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp @@ -236,7 +236,7 @@ inline matrix ugemm_KQ(uint slm_K, matrix_ref inline void ugemm_PV0(uint slm_V, matrix_ref P, matrix_ref rO, uint slm_offset = 0) { constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; - + auto P2 = P.format(); #pragma unroll for(int k = 0, ri = 0; k < _head_size; k += REG_N, ri += num_P_tiles) { @@ -312,6 +312,74 @@ vector online_softmax_update(matrix_ref St, vector_r return max_comp; } +#ifdef CM_HAS_LSC_UNTYPED_2D + #define cm_load_normal cm_load + #define cm_load_transpose cm_load + #define cm_load_vnni cm_load + #define cm_store_normal cm_store +#else + // simulation of LSC API using SVM API + template + inline void cm_load_normal(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { + static_assert(NBlocks == 1); + auto pitch = Desc.get_pitch() + 1; + auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); + #pragma unroll + for(int i = 0; i < BlockH; i++) { + cm_svm_block_read(base + i * pitch, Res.select(i*BlockW)); + } + } + + template + inline void cm_load_transpose(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { + static_assert(NBlocks == 1); + auto pitch = Desc.get_pitch() + 1; + auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); + matrix temp; + #pragma unroll + for(int i = 0; i < BlockH; i++) { + cm_svm_block_read(base + i * pitch, temp[i]); + } + Transpose2DMatrix(temp, Res.format()); + } + + // in VNNI case, NBlocks is increasing along X dimension (increase cache-line usage) + template + inline void cm_load_vnni(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { + static_assert(NBlocks == 1 || NBlocks == 2); + // each block must be a full XMX B matrix + static_assert(BlockH == REG_K); + static_assert(BlockW == REG_N); + auto pitch = Desc.get_pitch() + 1; + auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); + matrix temp; + #pragma unroll + for(int i = 0; i < BlockH; i++) { + cm_svm_block_read(base + i * pitch, temp[i]); + } + + auto out_vnni = Res.format(); + #pragma unroll + for(int i = 0; i < NBlocks; i ++) { + out_vnni.select(i*(BlockH/2), 0) = temp.select(0, i*BlockW); + out_vnni.select(i*(BlockH/2), 1) = temp.select(1, i*BlockW); + } + } + + template + inline void cm_store_normal(const lsc::block_2d_desc &Desc, vector_ref Res) { + static_assert(NBlocks == 1); + auto pitch = Desc.get_pitch() + 1; + auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); + #pragma unroll + for(int i = 0; i < BlockH; i++) { + cm_svm_block_write(base + i * pitch, Res.select(i*BlockW)); + } + } +#endif + + + //=============================================================================================== template constexpr void apply_causal_mask(matrix_ref St) { @@ -322,6 +390,7 @@ constexpr void apply_causal_mask(matrix_ref St) { } #ifdef CM_HAS_LSC_UNTYPED_2D + template void sdpa_kernel_lsc( uint slm_K, @@ -482,7 +551,6 @@ void sdpa_kernel_lsc( } } - template void sdpa_kernel_lsc_prefetch( int wg_local_id, @@ -662,6 +730,226 @@ void sdpa_kernel_lsc_prefetch( cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); } } + +template +void pa_kernel_lsc_prefetch( + int wg_local_id, + int q_start, + int kv_stop, // + int q_len, //q_step + int kv_len, //not used for now + svmptr_t q_base [[type("svmptr_t")]], + svmptr_t k_base [[type("svmptr_t")]], + svmptr_t v_base [[type("svmptr_t")]], +#if SPARSE_BLOCK_SIZE > 1 + svmptr_t sparse_mask_base [[type("svmptr_t")]], +#endif + svmptr_t o_base [[type("svmptr_t")]], + int32_t past_lens, + int32_t* block_indices [[type("svmptr_t")]]) { + constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); + constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; + // constexpr uint k_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); + // constexpr uint v_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); + //[block_num, kv_heads, block_size, head_size] + constexpr uint k_pitch = head_size * sizeof(half); + constexpr uint v_pitch = k_pitch; + + vector cur_max; + vector cur_sum; + + bool need_comp = false; + + cur_max = -3e38f; + cur_sum = 0; + constexpr int num_P_tiles = REG_N / REG_M; + matrix rQ; + matrix rO; + + auto q_tokens_left = q_len; + static_assert(q_step == REG_N); + static_assert(kv_step == REG_K); + + if (q_tokens_left < 0) q_tokens_left = 0; + if (q_tokens_left > q_step) q_tokens_left = q_step; + +#if SPARSE_BLOCK_SIZE > 1 + // printf("wg:%d.%d q: %d, +%d kv: %d, x-attn: %p\n", 0, wg_local_id, q_start, q_tokens_left, kv_stop, reinterpret_cast(sparse_mask_base)); +#endif + + if (q_tokens_left > 0) { + lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0); + #pragma unroll + for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { + cm_load(rQ[ri].format(), b2dQ.set_block_x(k)); + rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); + } + } + + lsc::block_2d_desc b2dK(k_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); + lsc::block_2d_desc b2dV(v_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); + + static_assert(wg_local_size == 16); + lsc::block_2d_desc prefetch_K(k_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); + lsc::block_2d_desc prefetch_V(v_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); + constexpr int blk_stride = CMFLA_NUM_KV_HEADS*CMFLA_HEAD_SIZE*CMPA_BLOCK_SZ; + int causal_left = q_start+past_lens; + + for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step) { + auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; + //For the last step, duplicate prefetch here. + uint32_t prefetch_kv_pos = (kv_pos+kv_step) >= kv_stop ? kv_pos : (kv_pos+kv_step); + auto prefetch_block_id = block_indices[prefetch_kv_pos / CMPA_BLOCK_SZ]; + //# St = k @ Qt + matrix St; // = ugemm_KQ(slm_K, rQ, slm_offset); + { + constexpr int num_K = kv_step/REG_M; + auto St2 = St.format(); + + matrix Kmat; + //cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); + + prefetch_K.set_base_ptr((reinterpret_cast(k_base)+prefetch_block_id*blk_stride)); + prefetch_K.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); + cm_prefetch(prefetch_K.set_block_x(0)); + +#if SPARSE_BLOCK_SIZE > 1 + { + auto kv_start_block = kv_pos/ SPARSE_BLOCK_SIZE; + bool sparse_mask = *(reinterpret_cast(sparse_mask_base) + kv_start_block); + if (!sparse_mask) { + if constexpr (use_causal_mask) { + causal_left -= kv_step; + } + continue; + } + } +#endif + + b2dK.set_base_ptr((reinterpret_cast(k_base)+cur_block_id*blk_stride)); + b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ); + cm_load(Kmat.format(), b2dK.set_block_x(0)); + #pragma unroll + for(int k = 0; k < num_K; k++) + St2.row(k) = cm_dpas( + 0, + rQ[0].format(), + Kmat[k].format()); + + #pragma unroll + for(int ri = 1; ri < head_size/REG_K; ri++) { + //cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); + cm_prefetch(prefetch_K.set_block_x(ri*REG_K)); + cm_load(Kmat.format(), b2dK.set_block_x(ri*REG_K)); + #pragma unroll + for(int k = 0; k < num_K; k++) { + St2.row(k) = cm_dpas( + St2.row(k), + rQ[ri].format(), + Kmat[k].format()); + } + } + } + if constexpr (use_causal_mask) { + // since kv_step == q_step == 16, causal_left is n*kv_step + if (causal_left == 0) { + apply_causal_mask<1>(St); + } else if (causal_left < 0) { + St = -3.4e38f; + } + causal_left -= kv_step; + } else { + int kv_tokens = kv_stop - kv_pos; + // LSC ensures no overflow-access, but mask off k-tails attn-score is still required + for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; + } + + // show(St); + auto max_comp = online_softmax_update(St, cur_max, cur_sum); + + matrix P; + Transpose2DMatrix(St, P); + + prefetch_V.set_base_ptr((reinterpret_cast(v_base)+prefetch_block_id*blk_stride)); + prefetch_V.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); + + b2dV.set_base_ptr((reinterpret_cast(v_base)+cur_block_id*blk_stride)); + b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + if (need_comp == false) { + // ugemm_PV0(slm_V, P, rO, slm_offset); + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + cm_prefetch(prefetch_V.set_block_x(k)); + cm_load(Vmat.format(), b2dV.set_block_x(k)); + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + 0, + Vmat.format(), + P2.row(p).format()); + // show(rO[ri + p].format()); + } + } + + need_comp = true; + } + else { + //ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + + cm_prefetch(prefetch_V.set_block_x(k)); + cm_load(Vmat.format(), b2dV.set_block_x(k)); + + //# compensate cur_O + // matrix rO; + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < REG_M; r++) + cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); + } + + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + rO[ri + p].format(), + Vmat.format(), + P2.row(p).format()); + // show(rO[ri + p].format()); + } + } + } + } + if (q_tokens_left == 0) return; + + //# save cur_O/cur_sum.transpose(0, 1) + matrix cur_O_f16; + cur_sum = cm_inv(cur_sum); + + lsc::block_2d_desc b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0); + + #pragma unroll + for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < cO.n_rows(); r++) { + cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); + + } + } + b2dO.set_block_x(k); + cm_store(b2dO.set_block_y(0), cur_O_f16.format().row(0)); + cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); + } +} #endif template diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm new file mode 100644 index 00000000000000..18623b400a31e8 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm @@ -0,0 +1,103 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include +#include + +#ifndef ATTR +#define ATTR [[type("svmptr_t")]] +#define ATTR_BUF [[type("buffer_t")]] +#endif + +constexpr uint wg_size = WG_SIZE; + +// extern "C" _GENX_MAIN_ void pa_kv_cache_update( +extern "C" _GENX_MAIN_ void KERNEL_NAME( + const half* key [[type("svmptr_t")]], + const half* value [[type("svmptr_t")]], + const int32_t* past_lens [[type("svmptr_t")]], + const int32_t* block_indices [[type("svmptr_t")]], + const int32_t* block_indices_begins [[type("svmptr_t")]], + const int32_t* subsequence_begins [[type("svmptr_t")]], + half* key_cache [[type("svmptr_t")]], + half* value_cache [[type("svmptr_t")]], + uint32_t key_pitch, + uint32_t value_pitch, + uint32_t batch_size_in_sequences) { + // # key: [batch_size_in_tokens, num_kv_heads * k_head_size] + // # value [batch_size_in_tokens, num_kv_heads * v_head_size] + // # key_cache: [num_blocks, num_heads, block_size, k_head_size] + // # value_cache: [num_blocks, num_heads, block_size, v_head_size] + // + // # past_lens: [sequences_num] + // # subsequence_begins: [sequences_num + 1] + // # block_indices: [used_blocks_num] + // # block_indices_begins: [sequences_num + 1] + + // wg_count = aligned_to(batch_size_in_tokens, wg_size) // wg_size + // # GWS [1, num_heads, wg_count * wg_size] + // # LWS [1, 1, wg_size] + + const auto head_idx = cm_group_id(1); + const auto wg_id = cm_group_id(2); + const auto wg_local_id = cm_local_id(2); + const auto local_size = cm_local_size(2); + + // static_assert(local_size == wg_size); + + // const uint token_idx = wg_id * local_size + wg_local_id; + const uint token_idx = cm_global_id(2); + + // token_idx -> subsequence_idx + if (token_idx >= subsequence_begins[batch_size_in_sequences]) return; + uint subsequence_idx = 0; + for (uint i = 0; i < batch_size_in_sequences; i++) { + if (token_idx >= subsequence_begins[i] && token_idx < subsequence_begins[i + 1]) { + subsequence_idx = i; + break; + } + } + + // printf("wg:%d.%d, token_idx: %d, subsequence_idx: %d\n", wg_id, wg_local_id, token_idx, subsequence_idx); + + const uint subsequence_begin_idx = subsequence_begins[subsequence_idx]; + + const uint past_len = past_lens[subsequence_idx]; + + const uint current_block_idx = (past_len + token_idx - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE; + const uint token_start_pos = (past_len + token_idx - subsequence_begin_idx) % PAGED_ATTENTION_BLOCK_SIZE; + + const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx; + + { + uint block_k_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; + uint key_out_offset = block_k_base_offset + token_start_pos * K_HEAD_SIZE; + uint key_in_offset = token_idx * key_pitch + head_idx * K_HEAD_SIZE; + + vector key_data; + key_data.format() = cm_ptr_load((int*)key, key_in_offset * (int)sizeof(half)); + cm_ptr_store((int*)key_cache, key_out_offset * (int)sizeof(half), key_data.format()); + } + { + uint block_v_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_V_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; + uint value_out_offset = block_v_base_offset + token_start_pos * V_HEAD_SIZE; + uint value_in_offset = token_idx * value_pitch + head_idx * V_HEAD_SIZE; + + vector value_data; + value_data.format() = cm_ptr_load((int*)value, value_in_offset * (int)sizeof(half)); + cm_ptr_store((int*)value_cache, value_out_offset * (int)sizeof(half), value_data.format()); + } +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm new file mode 100644 index 00000000000000..c69a19bcab5945 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm @@ -0,0 +1,147 @@ + +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +namespace KERNEL_NAME { +#include "cm_sdpa_common.hpp" + +#ifdef CM_HAS_LSC_UNTYPED_2D +#define USE_LSC 1 +#else +#define USE_LSC 0 +#endif + +//extern "C" _GENX_MAIN_ void pa_multi_token( +extern "C" _GENX_MAIN_ void KERNEL_NAME( + half* query [[type("svmptr_t")]], + half* key [[type("svmptr_t")]], + half* value [[type("svmptr_t")]], + int32_t* past_lens [[type("svmptr_t")]], + int32_t* block_indices [[type("svmptr_t")]], + int32_t* block_indices_begins [[type("svmptr_t")]], + int32_t* subsequence_begins [[type("svmptr_t")]], +#if SPARSE_BLOCK_SIZE > 1 + bool* sparse_block_mask [[type("svmptr_t")]], +#endif + half* output [[type("svmptr_t")]], + int q_len) { + constexpr int is_causal = CMFLA_IS_CAUSAL; + constexpr int num_heads = CMFLA_NUM_HEADS; + constexpr int head_size = CMFLA_HEAD_SIZE; + constexpr int num_kv_heads = CMFLA_NUM_KV_HEADS; + constexpr int pa_block_sz = CMPA_BLOCK_SZ; + //# query [q_len, num_heads, S] + //# key [kv_len, num_heads, S] + //# value [kv_len, num_heads, S] + //# sparse_block_mask [num_heads, q_blocks, kv_blocks] + + auto batch = cm_group_id(0); + auto h = cm_group_id(1); + auto hkv = h / (num_heads/num_kv_heads); + auto wg_id = cm_group_id(2); // each work-group handles a sequence + auto wg_local_id = cm_local_id(2); + int local_size = cm_local_size(2); + int q_start_sg, kv_start, kv_seq_len, q_len_sg; + + // multiple work-groups are required to split a sequence, + // need to figure out which part of query-tokens to process + int wg_seq_len = local_size * q_step; + int past_q_lens = past_lens[0]; + kv_start = 0; + kv_seq_len = q_len + past_q_lens; + q_start_sg = (wg_id * local_size + wg_local_id) * q_step; + q_len_sg = q_step; + if (q_start_sg + q_len_sg > q_len) { + q_len_sg = q_len - q_start_sg; + } + + // qkv is fused + int kv_stop = kv_seq_len; + if constexpr (is_causal) { + /* + -------------------------------- + | | | | | + | 00 | | | | + | | | | | + -------------------------------- + | | | | | + | 10 | 11 | | | + | | | | | + --------------------------------- + | | | | | + | 20 | 21 | 22 | | + | | | | | + --------------------------------- + | | | | | + | 30 | 31 | 32 | 33 | + | | | | | + --------------------------------- + each grid can be [q_len_per_trunk, q_len_per_trunk]. + For each trunk, [q_len_per_trunk, past_q_lens] must be calculated. Such as: `20`,`21`. but for the 22, + causal mask optimization can be applied. different wgs would has different kv stop. + //todo:kv_stop is wg level, should we change to sg level? + */ + kv_stop = (wg_id + 1) * wg_seq_len + past_q_lens; + if (kv_stop > kv_seq_len) kv_stop = kv_seq_len; + } + + // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop); + // qkv fused + // constexpr uint num_total_heads = num_heads + num_kv_heads * 2; + // uint q_offset = (q_start*num_total_heads + h)*head_size; + // uint k_offset = (kv_start*num_total_heads + num_heads + hkv)*head_size; + // uint v_offset = (kv_start*num_total_heads + num_heads + num_kv_heads + hkv)*head_size; + + //Q/O[B, L, H, S] + uint q_offset = (q_start_sg*num_heads + h)*head_size; + uint o_offset = (q_start_sg*num_heads + h)*head_size; + + //K/V[block_num, kv_heads, block_sz, head_sz] + uint k_offset = hkv*head_size*pa_block_sz; + uint v_offset = hkv*head_size*pa_block_sz; + +#if SPARSE_BLOCK_SIZE > 1 + //# sparse_block_mask [num_heads, q_blocks, kv_blocks] + auto q_start_block = q_start_sg/ SPARSE_BLOCK_SIZE; + int q_blocks = (q_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE; + int kv_blocks = (kv_seq_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE; + bool* block_mask_base = sparse_block_mask + (h * q_blocks + q_start_block)*kv_blocks; + // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d, x-attn: %d, %dx%d, %p, %p\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop, q_start_block, q_blocks, kv_blocks, sparse_block_mask, block_mask_base); +#endif + +#if USE_LSC == 1 + pa_kernel_lsc_prefetch( + wg_local_id, + q_start_sg, //q_start for SG, + kv_stop, + q_len_sg, //q_step, + kv_seq_len, //kv_len, not used for now + reinterpret_cast(query + q_offset), + reinterpret_cast(key + k_offset), + reinterpret_cast(value + v_offset), +#if SPARSE_BLOCK_SIZE > 1 + reinterpret_cast(block_mask_base), +#endif + reinterpret_cast(output + o_offset), + past_q_lens, + block_indices); +#else + static_assert(0); + +#endif +} + +} // NAMESPACE \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm new file mode 100644 index 00000000000000..8adea7ad4f9a22 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -0,0 +1,314 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +// xe-1:8, xe-2:16 +#if xe_arch==1 +#define REG_N 8 +#define USE_LSC_BLOCK_2D_DESC 0 +// #define KV_STEP 8 +#else +#define REG_N 16 +#define USE_LSC_BLOCK_2D_DESC 1 +// #define KV_STEP 16 +#endif + +#define SystolicDepth 8 +#define RepeatCount 1 +#define VNNI_WIDTH 2 +#define REG_K (SystolicDepth * VNNI_WIDTH) +#define REG_M RepeatCount + +#if 0 +#define HEADS_NUM +#define KV_HEADS_NUM +#define HEAD_SIZE +#define SCALE_FACTOR +#define KV_BLOCK_SIZE +#define KV_PARTITION_SIZE +#define Q_STEP +#define KV_STEP +#define WG_SIZE +#define XE_ARCH +#endif + +#define PARTITION_SUBBLOCK_NUM (KV_PARTITION_SIZE / KV_STEP) + +#define KV_SCALE_ZP_SIZE 0 // 4: scale/zp size + +extern "C" _GENX_MAIN_ void KERNEL_NAME( +// extern "C" _GENX_MAIN_ void cm_sdpa_2nd( + half* query [[type("svmptr_t")]], + half* key [[type("svmptr_t")]], + half* value [[type("svmptr_t")]], + int* past_lens [[type("svmptr_t")]], + int* block_indices [[type("svmptr_t")]], + int* block_indices_begins [[type("svmptr_t")]], + int* subsequence_begins [[type("svmptr_t")]], + half* output [[type("svmptr_t")]], + // half* mask [[type("svmptr_t")]], + float* lse [[type("svmptr_t")]], + // int* gws_subseq_mapping [[type("svmptr_t")]], + int q_len// 1 + ) { + //# batch=1, seq_num=1 or >1 + //# query [seq_idx, seq_num, head_num, head_size] + //# output[seq_idx, seq_num, head_num, head_size] + //# key [block_num, head_num, block_size, head_size] + [block_num, head_num, block_size, 4] (scale/zp) + //# value [block_num, head_num, block_size, head_size] + [block_num, head_num, block_size, 4] (scale/zp) + + //# KV_PARTITION_SIZE should be multiple of kv_block_size(KV_BLOCK_SIZE) + //# kv_len dimision will be split into multiple partitions, each WG process a partition + //# total_partitions_num = kv_len // KV_PARTITION_SIZE + //# GWS=[seq_num, num_heads, total_partitions_num] + //# LWS=[1, 1, 1] + + //# Each WG processes a partition, which is KV_PARTITION_SIZE long and multiple of KV_BLOCK_SIZE. + //# KV_BLOCK_SIZE can be 32/64/128/256, etc. + const auto seq_idx = cm_global_id(0); + const auto head_num_idx = cm_global_id(1); + const auto kv_head_num_idx = head_num_idx / (HEADS_NUM/KV_HEADS_NUM); + //# const auto wg_local_id = cm_local_id(2); + //# KV_PARTITION_SIZE --> EU thread + const auto wg_thread_id = cm_global_id(2); + const uint kv_partition_num = cm_group_count(2); + const uint partition_idx = cm_group_id(2); + + // # const uint subsequence_idx = gws_subseq_mapping[seq_idx]; + const uint subsequence_idx = seq_idx; + + //# const uint subsequence_begin = subsequence_begins[subsequence_idx]; + //# const uint subsequence_end = subsequence_begins[subsequence_idx + 1]; + const uint kv_len = past_lens[subsequence_idx] + 1; + const uint start_block_idx = block_indices_begins[subsequence_idx] + partition_idx * (KV_PARTITION_SIZE / KV_BLOCK_SIZE); + + if(partition_idx * KV_PARTITION_SIZE > kv_len) { + return; + } + const uint total_blocks_num = (kv_len + KV_BLOCK_SIZE - 1) / KV_BLOCK_SIZE; + + //#TODO: int8 compression data + uint kv_pitch = HEAD_SIZE * sizeof(half); + //# fp16 data + //# uint qo_pitch = HEADS_NUM * HEAD_SIZE * sizeof(half); + + //# Load Q into register(as dpas-A tile) + matrix Qmat; + uint qo_offset = (seq_idx*HEADS_NUM*q_len + head_num_idx)*HEAD_SIZE; + for(int k = 0, ri = 0; k < HEAD_SIZE; k += REG_K, ri++) { + cm_svm_block_read((svmptr_t)(query + qo_offset + k), Qmat[ri].format()); + } + + // if(wg_thread_id==0 && head_num_idx == 0) { + // printf("Qmat loaded, wg_thread_id=%d\n", wg_thread_id); + // show(Qmat); + //} + + const uint per_kv_block_element_num = KV_BLOCK_SIZE * KV_HEADS_NUM * (HEAD_SIZE + KV_SCALE_ZP_SIZE / sizeof(half)); // 4: scale/zp + uint block_num = KV_PARTITION_SIZE / KV_BLOCK_SIZE; + + uint leftover_size = 0; + if(block_num > total_blocks_num - start_block_idx) { + block_num = total_blocks_num - start_block_idx; + leftover_size = kv_len - KV_PARTITION_SIZE * partition_idx; + leftover_size = KV_STEP * ((leftover_size + KV_STEP - 1) / KV_STEP); // round up to KV_STEP + } + + //# rS = Q @ Kt + //# PARTITION_SUBBLOCK_NUM * [REG_M, REG_K] * [REG_K, REG_N] = PARTITION_SUBBLOCK_NUM * [REG_M, REG_N] + matrix rS = 0; + // # each WI can process multiple blocks + for(uint block_idx = 0, ki = 0; block_idx < block_num; block_idx++) { + uint blk_indices = block_indices[start_block_idx + block_idx]; + uint kv_base_offset = blk_indices * per_kv_block_element_num + kv_head_num_idx * (per_kv_block_element_num / KV_HEADS_NUM); + uint kv_scale_zp_offset = kv_base_offset + KV_BLOCK_SIZE * HEAD_SIZE; // scale/zp offset + + // printf("seq_idx = %d, head_num_idx = %d, partition_idx = %d, start_block_idx = %d, block_idx = %d, blk_indices = %d, KV_PARTITION_SIZE = %d, KV_BLOCK_SIZE = %d, total_blocks_num = %d, seq_len = %d, kv_base_offset = %d\n", + // seq_idx, head_num_idx, partition_idx, start_block_idx, block_idx, blk_indices, KV_PARTITION_SIZE, KV_BLOCK_SIZE, total_blocks_num, seq_len, kv_base_offset); + + #if USE_LSC_BLOCK_2D_DESC + //# vector load cannot be used for block_2d_desc + //# note: candidate template ignored: deduced type 'details::Block2DRefTy' (aka 'vector_ref') of 1st parameter + //# b2dK reinterpret as 32bit(DWORD) for transposed load(combined with VNNI) + lsc::block_2d_desc b2dK(reinterpret_cast(key + kv_base_offset), KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(half) - 1, kv_pitch - 1, 0, 0); + //printf("b2dK: kv_base_offset = %d, KV_BLOCK_SIZE = %d, HEAD_SIZE = %d, kv_pitch = %d, blk_indices = %d, block_idx = %d, start_block_idx = %d\n", + // kv_base_offset, KV_BLOCK_SIZE, HEAD_SIZE, kv_pitch, blk_indices, block_idx, start_block_idx); + #else + uint kv_offset = kv_base_offset; + uint kv_stride = HEAD_SIZE; + uint kv_x0 = 0, kv_y0 = 0; + uint kv_x1 = HEAD_SIZE*sizeof(half); + uint kv_y1 = KV_BLOCK_SIZE; + #endif + + uint kv_pos_end = KV_BLOCK_SIZE; + if(block_idx == block_num - 1 && leftover_size > 0) { + kv_pos_end = leftover_size % KV_BLOCK_SIZE; + } + for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += KV_STEP, ki++) { + auto rSvec = rS[ki].format(); + uint kv_offset_y = kv_pos; + + #pragma unroll + for(int k = 0, ri = 0; k < HEAD_SIZE/2; k += REG_K/2, ri ++ ) { + matrix Kt; + #if USE_LSC_BLOCK_2D_DESC + //# Load Kt into register & pack as VNNI(as dpas-B tile) + //# DWORD transposed load == (transposed + VNNI) load + b2dK.set_block_x(k); + cm_load(Kt.format(), b2dK.set_block_y(kv_offset_y)); + #else + matrix temp; + uint cur_kv_offset = kv_offset + kv_offset_y * kv_stride + k * 2;// uint --> half + #pragma unroll + for(int kk = 0; kk < REG_N; kk++) { + cm_svm_block_read((svmptr_t)(key + cur_kv_offset + kk * kv_stride), temp[kk].format()); + } + #if XE_ARCH==1 + Transpose_8x8(temp.select<8,1,8,1>(0,0), Kt.format().select<8,1,8,1>(0,0)); + #else + Transpose_8x8(temp.select<8,1,8,1>(0,0), Kt.format().select<8,1,8,1>(0,0)); + Transpose_8x8(temp.select<8,1,8,1>(8,0), Kt.format().select<8,1,8,1>(0,8)); + #endif + #endif + rSvec = cm_dpas( + rSvec, + Kt.format(), + Qmat[ri].format()); + } + } + } + + // printf("rS:\n"); + // show(rS); + + // online softmax + float cur_sum = 0.0f; + float cur_lse = 0.0f; + #if XE_ARCH==1 + matrix Pmat = 0; + #else + matrix Pmat = 0; + #endif + { + //# Load Mask into register + // matrix MaskMat; + // uint mask_offset = seq_idx * q_len * kv_len + wg_thread_id * KV_PARTITION_SIZE; + // cm_svm_block_read((svmptr_t)(mask + mask_offset), MaskMat.format()); + + rS = cm_mul(rS, (float)SCALE_FACTOR); // convert scale_factor into (float), or it will be promoted to double + //rS = cm_add(rS, MaskMat); + + // compute lse + constexpr float log2e = 1.4426950408889634f; + vector rS_exp = cm_exp(rS.format()*log2e); + cur_lse += cm_sum(rS_exp); + + // compute row_max + auto rSv = rS.format(); + float row_max = rSv[0]; + for(int r = 1; r < rSv.n_elems(); r++) + row_max = cm_max(row_max, rSv[r]); + + // compute P = exp(rS - row_max) + #if XE_ARCH==1 + Pmat= cm_exp((rS.format() - row_max)*log2e); + #else + Pmat= cm_exp((rS - row_max)*log2e); + #endif + + // compute row sum of P + auto rPv = Pmat.format(); + cur_sum = cm_sum(rPv[0]); + } + + //if(wg_thread_id==0) { + // printf("Pmat:\n"); + // show(Pmat); + //} + + //# rO = P * V + matrix Omat = 0; + for(uint block_idx = 0, ki = 0; block_idx < block_num; block_idx++) { + uint blk_indices = block_indices[start_block_idx + block_idx]; + uint kv_base_offset = blk_indices * per_kv_block_element_num + kv_head_num_idx * (per_kv_block_element_num / KV_HEADS_NUM); + uint kv_scale_zp_offset = kv_base_offset + KV_BLOCK_SIZE * HEAD_SIZE; // scale/zp offset + + #if USE_LSC_BLOCK_2D_DESC + //# vector load cannot be used for block_2d_desc + //# note: candidate template ignored: deduced type 'details::Block2DRefTy' (aka 'vector_ref') of 1st parameter + lsc::block_2d_desc b2dV(value + kv_base_offset, KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(half) - 1, kv_pitch - 1, 0, 0); + #else + uint kv_offset = kv_base_offset; + uint kv_stride = HEAD_SIZE; + uint kv_x0 = 0, kv_y0 = 0; + uint kv_x1 = HEAD_SIZE*sizeof(half); + uint kv_y1 = KV_BLOCK_SIZE; + #endif + + uint kv_pos_end = KV_BLOCK_SIZE; + if(block_idx == block_num - 1 && leftover_size > 0) { + kv_pos_end = leftover_size % KV_BLOCK_SIZE; + } + for(int kv_pos =0; kv_pos < kv_pos_end; kv_pos += REG_K, ki++) { + uint kv_offset_y = kv_pos; + #pragma unroll + for(int k = 0, ri = 0; k < HEAD_SIZE; k += REG_N, ri ++ ) { + // Load V into register & pack as VNNI(as dpas-B tile) + matrix Vmat; + #if USE_LSC_BLOCK_2D_DESC + b2dV.set_block_x(k); + cm_load(Vmat[0].format(), b2dV.set_block_y(kv_offset_y)); + #else + matrix temp; + uint cur_kv_offset = kv_offset + kv_offset_y * kv_stride + k; + #pragma unroll + for(int kk = 0; kk < REG_K; kk++) { + cm_svm_block_read((svmptr_t)(value + cur_kv_offset + kk * kv_stride), temp[kk].format()); + } + auto Vref = Vmat[0].format(); + Vref.select(0, 0) = temp.select(0, 0); + Vref.select(0, 1) = temp.select(1, 0); + #endif + Omat[ri] = cm_dpas( + Omat[ri], + Vmat[0].format(), + Pmat[ki].format()); + } + } + } + + //if(wg_thread_id==0) { + // printf("Omat:\n"); + // show(Omat); + //} + + //# save Output + matrix cur_O_f16; + uint o_offset = seq_idx * kv_partition_num * HEADS_NUM * HEAD_SIZE + kv_partition_num * head_num_idx * HEAD_SIZE + wg_thread_id * HEAD_SIZE; + float div_cur_sum = 1.0/cur_sum; + #pragma unroll + for(int k = 0, ri=0; k < HEAD_SIZE; k += REG_N, ri++) { + auto cO = Omat[ri].format(); + #if XE_ARCH==1 + cur_O_f16= cm_mul(cO, div_cur_sum); + #else + cur_O_f16= cm_div_ieee(cO, cur_sum); + #endif + cm_svm_block_write((svmptr_t)(output + o_offset + k),cur_O_f16.format()); + } + uint lse_offset = seq_idx * HEADS_NUM * kv_partition_num + head_num_idx * kv_partition_num + wg_thread_id; + lse[lse_offset] = cur_lse; +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm new file mode 100644 index 00000000000000..001eb39ad1435c --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm @@ -0,0 +1,54 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#if 0 +#define HEADS_NUM +#define HEAD_SIZE +#define REDUCE_SPLIT_SIZE +#endif + +//cm_sdpa_2nd_reduce +extern "C" _GENX_MAIN_ void KERNEL_NAME( +// extern "C" _GENX_MAIN_ void cm_sdpa_2nd_reduce( + half* input [[type("svmptr_t")]], // + half* output [[type("svmptr_t")]], + float* lse [[type("svmptr_t")]], + int kv_partition_num + ) { + auto batch = cm_global_id(0); + auto head = cm_global_id(1); + auto offset = cm_group_id(2) * REDUCE_SPLIT_SIZE; + const int total_partition_num = (kv_partition_num * HEADS_NUM); + + // load lse + #if 0 + uint lse_offset = batch * total_partition_num + head * kv_partition_num; + vector lse_vec; + cm_svm_block_read((svmptr_t)(lse + lse_offset), lse_vec.format()); + float total_lse = cm_sum(lse_vec); + #else + float total_lse = 0.0; + uint lse_offset = batch * total_partition_num + head * kv_partition_num; + float* lse_vec = lse + lse_offset; + #pragma unroll + for(int k = 0; k < kv_partition_num; k ++) { + total_lse += lse_vec[k]; + } + #endif + + // load input, total_partition_num = head_nums * kv_partition_num; + matrix out_mat = 0; + matrix data_mat; + uint input_offset = batch * total_partition_num * HEAD_SIZE + head * kv_partition_num * HEAD_SIZE + offset; + #pragma unroll + for(int k = 0; k < kv_partition_num; k ++) { + cm_svm_block_read((svmptr_t)(input + input_offset), data_mat.format()); + input_offset += HEAD_SIZE; + out_mat += cm_mul(data_mat, (float)(lse_vec[k]/total_lse)); + } + + // write output + uint output_offset = batch * HEADS_NUM * HEAD_SIZE + head * HEAD_SIZE + offset; + cm_svm_block_write((svmptr_t)(output + output_offset),out_mat.format()); + } \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp new file mode 100644 index 00000000000000..fc3c18e39f3dfe --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -0,0 +1,159 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "paged_attention.hpp" +#include "paged_attention_gen.hpp" + +#include +#include +#include +#include + +#include "primitive_cm_base.hpp" +#include "common_utils/kernel_generator_base.hpp" +#include "common_utils/jitter.hpp" +#include "intel_gpu/graph/kernel_impl_params.hpp" +#include "intel_gpu/primitives/paged_attention.hpp" +#include "kv_cache_inst.h" +#include "openvino/core/partial_shape.hpp" +#include "paged_attention_inst.h" +#include "primitive_inst.h" + +namespace ov::intel_gpu::cm { + +class PagedAttentionCmImpl : public PrimitiveImplCM { +public: + DECLARE_OBJECT_TYPE_SERIALIZATION(ov::intel_gpu::cm::PagedAttentionCmImpl) + + Stage::Ptr kv_cache_update = make_stage(); + Stage::Ptr pa_single_token = make_stage(); + Stage::Ptr pa_single_token_finalization = make_stage(); + Stage::Ptr pa_multi_token = make_stage(); + + PagedAttentionCmImpl(): PrimitiveImplCM(PagedAttentionImplementationManager::get_type_info_static()) {} + explicit PagedAttentionCmImpl(const kernel_impl_params& params) : PagedAttentionCmImpl() { + const auto desc = params.typed_desc(); + + std::cout << "ov::intel_gpu::cm::PagedAttentionCmImpl::PagedAttentionCmImpl()" << std::endl; + add_stage(kv_cache_update, params); + add_stage(pa_single_token, params); + add_stage(pa_single_token_finalization, params); + add_stage(pa_multi_token, params); + } + + void update_rt_params(const primitive_inst& instance) override { + update_stages_flags(instance); + if (m_rt_params == nullptr) { + m_rt_params = std::make_unique(); + } + std::cout << "ov::intel_gpu::cm::PagedAttentionCmImpl::update_rt_params()" << std::endl; + const auto& params = *instance.get_impl_params(); + auto rt_params = static_cast(m_rt_params.get()); + const auto& desc = params.typed_desc(); + + const auto max_context_len = get_max_context_len(params); + rt_params->max_context_len = max_context_len; + rt_params->partition_size = get_partition_size(); + rt_params->num_of_partitions = ceil_div(max_context_len, rt_params->partition_size); + rt_params->stage = get_paged_attention_stage(params); + + std::cout << " max_context_len: " << rt_params->max_context_len << " partition_size: " << rt_params->partition_size + << " num_of_partitions: " << rt_params->num_of_partitions << ", stage: " << static_cast(rt_params->stage) << std::endl; + } + + // update impl_parameter and rt_parameter + void update(primitive_inst& inst, const kernel_impl_params& impl_params) override { + PrimitiveImplCM::update(inst, impl_params); + update_rt_params(inst); + } + + event::ptr execute(const std::vector& events, primitive_inst& instance) override { + const auto& params = *instance.get_impl_params(); + const auto desc = params.typed_desc(); + + update_stages_flags(instance); + auto rt_params = static_cast(m_rt_params.get()); + assert(rt_params != nullptr); + + std::cout << "ov::intel_gpu::cm::PagedAttentionCmImpl::execute(): stage = " << static_cast(rt_params->stage) << std::endl; + std::vector res_event = events; + res_event = {execute_stage(res_event, instance, kv_cache_update)}; + + if (rt_params->stage == PagedAttentionStage::PREFILL || rt_params->stage == PagedAttentionStage::MIXED) { + res_event = {execute_stage(res_event, instance, pa_multi_token)}; + } else if (rt_params->stage == PagedAttentionStage::GENERATE) { + res_event = {execute_stage(res_event, instance, pa_single_token)}; + res_event = {execute_stage(res_event, instance, pa_single_token_finalization)}; + } + return res_event[0]; + } + + bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const override { + const auto stage = get_paged_attention_stage(impl_params); + + // In case of MIXED mode execution Paged Attention may require dispatch data update and internal + // buffers reallocation even if the input shapes haven't been changed. Therefore, check the current execution + // mode and update parameters if needed + return stage == PagedAttentionStage::MIXED; + } + + [[nodiscard]] std::vector get_internal_buffer_descs(const kernel_impl_params& params) const override { + std::vector internal_buffers; + + const auto desc = params.typed_desc(); + const auto indexes_dt = ov::element::u8; + const auto element_size = 4; // 4 bytes + auto stage = PagedAttentionStage::UNKNOWN; + auto rt_params = static_cast(m_rt_params.get()); + + size_t partition_size = PA_KV_CACHE_BLOCK_SIZE; + size_t num_of_partitions = 1; + if (rt_params != nullptr && rt_params->num_of_partitions != 0) { + stage = rt_params->stage; + partition_size = rt_params->partition_size; + num_of_partitions = rt_params->num_of_partitions; + } else { + stage = get_paged_attention_stage(params); + const auto max_context_len = get_max_context_len(params); + partition_size = get_partition_size(); + num_of_partitions = ceil_div(max_context_len, partition_size); + } + std::cout << "ov::intel_gpu::cm::PagedAttentionCmImpl::get_internal_buffer_descs(): stage = " << static_cast(stage) + << " partition_size: " << partition_size << " num_of_partitions: " << num_of_partitions << std::endl; + if (stage == PagedAttentionStage::GENERATE) { + const auto& input = params.input_layouts[0]; + const int64_t total_tokens = input.get_partial_shape()[0].get_length(); + auto buf_elements_count = static_cast(total_tokens * desc->heads_num * num_of_partitions); + auto tmp_out_elements_count = static_cast(total_tokens * desc->heads_num * desc->v_head_size * num_of_partitions); + + internal_buffers.emplace_back(tmp_out_elements_count * element_size, indexes_dt); // 0: intermediate partition output + internal_buffers.emplace_back(buf_elements_count * element_size, indexes_dt); // 1: softmax exp_sums + + std::cout << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * element_size << " exp_sums=" << buf_elements_count * element_size + << std::endl; + } else { + internal_buffers.emplace_back(16, indexes_dt); // 0: intermediate partition output + internal_buffers.emplace_back(16, indexes_dt); // 1: softmax exp_sums + } + + return internal_buffers; + } + + [[nodiscard]] std::unique_ptr clone() const override { + return make_deep_copy(this); + } +}; + +std::unique_ptr PagedAttentionImplementationManager::create_impl(const program_node& node, const kernel_impl_params& params) const { + assert(node.is_type()); + try { + return std::make_unique(params); + } catch (const std::exception& e) { + OPENVINO_THROW("Failed to create PagedAttentionCmImpl: ", e.what()); + } +} + +} // namespace ov::intel_gpu::cm +// BIND_BINARY_BUFFER_WITH_TYPE(cldnn::paged_attention) +BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::cm::PagedAttentionCmImpl) \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp new file mode 100644 index 00000000000000..e5909242813758 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp @@ -0,0 +1,64 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include "program_node.h" +#include "intel_gpu/runtime/layout.hpp" +#include "registry/implementation_manager.hpp" +#include "paged_attention_inst.h" + +using namespace cldnn; // TODO: Remove once namespaces are aligned + +namespace ov::intel_gpu::cm { + +struct PagedAttentionImplementationManager : public ImplementationManager { + OV_GPU_PRIMITIVE_IMPL("cm::paged_attention::opt") + explicit PagedAttentionImplementationManager(shape_types shape_type, ValidateFunc vf = nullptr) + : ImplementationManager(impl_types::cm, shape_type, std::move(vf)) {} + [[nodiscard]] std::unique_ptr create_impl(const program_node& node, const kernel_impl_params& params) const override; + [[nodiscard]] bool validate_impl(const program_node& node) const override { + static constexpr std::array supported_q_types = { + ov::element::f16, + }; + static constexpr std::array supported_kv_types = { + ov::element::f16, + // ov::element::i8, + }; + + auto& engine = node.get_program().get_engine(); + const auto& config = node.get_program().get_config(); + const auto& info = engine.get_device_info(); + // CM optimized for systolic-array architectures + if (!check_cm_jit_support(engine, config) || !info.supports_immad || !config.get_use_cm()) { + std::cout << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + return false; + } + + const auto& q_layout = node.get_input_layout(0); + const auto& k_layout = node.get_input_layout(1); + const auto& v_layout = node.get_input_layout(2); + const auto& out_layout = node.get_output_layout(0); + if (!everyone_is(format::bfyx, q_layout.format, k_layout.format, v_layout.format, out_layout.format)) { + std::cout << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + return false; + } + + if (!one_of(k_layout.data_type, supported_kv_types) || !one_of(v_layout.data_type, supported_kv_types)) { + std::cout << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + return false; + } + + if (!one_of(q_layout.data_type, supported_q_types) || !one_of(out_layout.data_type, supported_q_types)) { + std::cout << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + return false; + } + + std::cout << "ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - true" << std::endl; + return true; + } +}; +} // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp new file mode 100644 index 00000000000000..8b271270b5f9c4 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -0,0 +1,492 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "paged_attention_gen.hpp" + +#include +#include +#include +#include + +#include "intel_gpu/primitives/paged_attention.hpp" +#include "intel_gpu/runtime/memory.hpp" +#include "openvino/core/partial_shape.hpp" +#include "paged_attention_inst.h" +#include "primitive_cm_base.hpp" +#include "primitive_inst.h" + +namespace ov::intel_gpu::cm { + +using namespace ov; +using namespace ov::intel_gpu::ocl; +using namespace cldnn; +namespace { +constexpr size_t WG_SIZE = 16; +constexpr size_t reduce_split_step = 16; +} // namespace + +// This function returns the kv_step and kv_split_len based on the architecture. +// return {kv_step, kv_split_len} +inline std::pair get_kv_split_size(size_t arch) { + if (arch == 1) { + return {8, 32}; // For Xe1 + } else if (arch == 2) { + return {16, 32}; // For Xe2 + } + OPENVINO_ASSERT(false, "Unsupported architecture for KV split size"); + return {0, 0}; // Fallback case, should not be reached +} + +inline size_t get_q_step(size_t arch, bool is_single_token = false) { + if (arch == 1) { + return is_single_token ? 1 : 8; // For Xe1 + } else if (arch == 2) { + return is_single_token ? 1 : 32; // For Xe2 + } + OPENVINO_ASSERT(false, "Unsupported architecture for Q step"); + return 0; // Fallback case, should not be reached +} + +inline size_t get_kv_len(const RuntimeParams& params, const PagedAttentionStage& stage) { + if (stage == PagedAttentionStage::PREFILL) { + auto key_shape = params.input_layouts[PagedAttentionInputIdx::KEY].get_shape(); + const size_t kv_len = key_shape[key_shape.size() - 2]; + return kv_len; + } else { + // key_cache shape = [block_num, head_num, block_size(128), head_size] + auto key_cache_shape = params.input_layouts[PagedAttentionInputIdx::KEY_CACHE].get_shape(); + const size_t kv_len = key_cache_shape[0] * key_cache_shape[2]; + return kv_len; + } + OPENVINO_ASSERT(false, "Unsupported PagedAttentionStage for get_kv_len"); + return 0; // Fallback case, should not be reached +} + +inline size_t get_aligned_kv_len(const size_t kv_len) { + return (kv_len + PA_KV_CACHE_BLOCK_SIZE - 1) / PA_KV_CACHE_BLOCK_SIZE * PA_KV_CACHE_BLOCK_SIZE; +} + +int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, int64_t target_seq_len_block_size = 16) { + // Since at prefill stage Q, K, V inputs may contain multiple sequences with arbitrary + // target sequence lengths each (shape is [sequences_num * target_seq_len, num_heads * head_size]), + // to apply blocking to the first dimension (target_seq_len of each sequence), we need to calculate aligned total + // target sequence length for proper kernel dispatching + // For instance, if input contains two sequences with 35 and 28 sequence lengths each, + // the Q, K, V inputs at prefill stage will have shapes [35 + 28, num_heads * head_size]; considering kernel's + // target_seq_len_block_size equals 16, we need to launch kernel instances for the following ranges: + // [0, 15], [16, 31], [32, 34], [35, 50], [51, 62], so aligned target_seq_len_block_size should be 5 * 16 = 80, + // and 5 kernels instances should be launched (for each range, some of them containing leftovers) + // + // In general, to obtain length for each sequence, we have to parse subsequence_begins input, + // which contains begin and end indexes for each sequence (for above example it will contain three values: {0, 35, 63}) + // However, as long as kernel's target_seq_len_block_size matches with vLLM's block_size, + // we can reuse block_indices_shape[0] size to determine total aligned sequences length size, avoiding + // memory access at runtime, because vLLM internally uses similar logic to configure blocks for KV cache + + auto calculate_aligned_seq_len = [&]() { + const auto& input_mem = impl_param.memory_deps; + const auto subsequence_begins_mem = input_mem.at(PagedAttentionInputIdx::SUBSEQUENCE_BEGINS); + mem_lock subsequence_begins_mem_lock(subsequence_begins_mem, *impl_param.strm); + + auto aligned_seq_len = 0; + if (stage == PagedAttentionStage::MIXED) { + const auto past_lens_mem = input_mem.at(PagedAttentionInputIdx::PAST_LENS); + mem_lock past_lens_mem_lock(past_lens_mem, *impl_param.strm); + + for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) { + auto past_len = past_lens_mem_lock[i]; + auto seq_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i]; + + // Since in MIXED execution mode the present KV-cache can be appended to the past KV-cache at any offset inside block, + // to ensure proper alignment and update_kv_cache kernel scheduling, we need to account for the number of unaligned tokens + // in the first block + // For example, if we need to store values in the following slots: + // + // block0: |O|O|O|O|O|O|O|O|O|O|O|O|U|U|U|U| + // block1: |U|U|U|U|U|U|U|U|U|U|U|U|U|U|U|U| + // block2: |U|U|U|U|U|U|E|E|E|E|E|E|E|E|E|E| + // Where O - occupied slots, U - currently beeing updated slots, E - empty slots + // + // We need to schedule 3 update_kv_cache operations: + // - For ranges of block0: [12-15] + // - For ranges of block1: [0-15] + // - For ranges of block2: [0-5] + // + // Therefore, consider an additional increment of aligned_seq_len to properly process all the blocks + + auto occupied_slots_num = past_len % target_seq_len_block_size; + if (past_len != 0 && seq_length + occupied_slots_num > target_seq_len_block_size) { + aligned_seq_len += target_seq_len_block_size; + seq_length -= target_seq_len_block_size - occupied_slots_num; + } + + aligned_seq_len += align_to(seq_length, target_seq_len_block_size); + } + } else { + for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) { + auto prompt_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i]; + aligned_seq_len += align_to(prompt_length, target_seq_len_block_size); + } + } + + return aligned_seq_len; + }; + + int64_t aligned_seq_len = 0; + if (stage == PagedAttentionStage::PREFILL) { + const auto desc = impl_param.typed_desc(); + if (static_cast(paged_attention::block_size) == target_seq_len_block_size) { + const auto& block_indices_ps = impl_param.get_input_layout(PagedAttentionInputIdx::BLOCK_INDICES).get_partial_shape(); + + aligned_seq_len = block_indices_ps[0].get_length() * target_seq_len_block_size; + } else { + aligned_seq_len = calculate_aligned_seq_len(); + } + } else { + aligned_seq_len = calculate_aligned_seq_len(); + } + + return aligned_seq_len; +} + +size_t get_partition_size() { + // size_t k_partition_blok_num = (kv_len + 8191) / 8192; + // if (k_partition_blok_num < 1) + // k_partition_blok_num = 1; + const size_t k_partition_blok_num = 1; + return k_partition_blok_num * PA_KV_CACHE_BLOCK_SIZE; +} + +size_t get_partition_num(const size_t kv_len) { + const size_t partition_size = get_partition_size(); + const size_t partition_num = (kv_len + partition_size - 1) / partition_size; + + return partition_num; +} + +// max_context_len = max(past_lens + prompt_lens) +size_t get_max_context_len(const kernel_impl_params& params) { + const auto& input_mem = params.memory_deps; + const auto max_context_len = input_mem.at(PagedAttentionInputIdx::MAX_CONTEXT_LEN); + mem_lock max_context_len_mem_lock(max_context_len, *params.strm); + const auto paged_attention_max_len = max_context_len_mem_lock[0]; + return paged_attention_max_len; +} + +size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx) { + const auto& input_mem = params.memory_deps; + const auto past_len = input_mem.at(PagedAttentionInputIdx::PAST_LENS); + mem_lock past_len_mem_lock(past_len, *params.strm); + const auto paged_attention_past_len = past_len_mem_lock[seq_idx]; + return paged_attention_past_len; +} + +PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) { + const auto& query_shape = impl_param.get_input_layout(PagedAttentionInputIdx::QUERY).get_partial_shape(); + const auto& past_lens_shape = impl_param.get_input_layout(PagedAttentionInputIdx::PAST_LENS).get_partial_shape(); + + if (query_shape.is_static() && past_lens_shape.is_static()) { + if (query_shape[0].get_length() == past_lens_shape[0].get_length()) { + return PagedAttentionStage::GENERATE; + } + + const auto& memory_deps = impl_param.memory_deps; + const auto past_lens_mem = memory_deps.at(PagedAttentionInputIdx::PAST_LENS); + mem_lock past_lens_mem_lock(past_lens_mem, *impl_param.strm); + + const auto past_lens_size = past_lens_mem_lock.size(); + for (size_t i = 0; i < past_lens_size; i++) { + if (past_lens_mem_lock[i] != 0) { + return PagedAttentionStage::MIXED; + } + } + return PagedAttentionStage::PREFILL; + } + return PagedAttentionStage::UNKNOWN; +} + +JitConstants PagedAttentionGeneratorBase::get_jit_constants(const kernel_impl_params& params) const { + auto jit = KernelGenerator::get_jit_constants(params); + jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); + std::cout << "PagedAttentionGeneratorBase::get_jit_constants: " << get_entry_point(params) << std::endl; + + // auto desc = params.typed_desc(); + auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; + jit.make("XE_ARCH", xe_arch); + + auto split_size = get_kv_split_size(xe_arch); + jit.make("KV_STEP", split_size.first); + + jit.make("WG_SIZE", WG_SIZE); + jit.make("CAUSAL_MASK", 1); + return jit; +} + +Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_impl_params& params) const { + const auto desc = params.typed_desc(); + + Arguments args; + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::QUERY}); // query + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY}); // key + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE}); // value + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::PAST_LENS}); // past_lens + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block_indices + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block_indices_begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence_begins +# if PA_SPARSE_BLOCK_SIZE > 1 + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SPARSE_BLOCK_MASK}); // sparse_block_mask +# endif + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); + + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len + return args; +} + +JitConstants PagedAttentionGeneratorMultiToken::get_jit_constants(const kernel_impl_params& params) const { + auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); + const auto desc = params.typed_desc(); + const float scale_factor = 1.0 / std::sqrt(static_cast(desc->k_head_size)); + auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; + + jit.make("CMFLA_NUM_HEADS", desc->heads_num); + jit.make("CMFLA_NUM_KV_HEADS", desc->kv_heads_num); + jit.make("CMFLA_HEAD_SIZE", desc->k_head_size); + jit.add(make_jit_constant("CMFLA_SCALE_FACTOR", scale_factor)); + jit.make("CMFLA_IS_CAUSAL", 1); + jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE); + jit.make("SPARSE_BLOCK_SIZE", PA_SPARSE_BLOCK_SIZE); + jit.make("Q_STEP", get_q_step(xe_arch, true)); + // for (auto& it : jit) { + // std::cout << "\tjit[" << it.name << "] = " << it.value << std::endl; + // } + // std::cout << std::endl; + return jit; +} + +DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() const { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + auto& wgs = kd.params.workGroups; + auto& scalars = kd.params.scalars; + auto desc = params.typed_desc(); + // auto rtp = static_cast(rt_params); + // assert(rt_params != nullptr); + const size_t heads_num = desc->heads_num; + + auto out_shape = params.output_layouts[0].get_shape(); + const size_t batch = out_shape.size() < 4 ? 1 : out_shape[0]; + const size_t q_len = out_shape[0]; + + auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; + const size_t q_step = get_q_step(xe_arch, false); + const size_t wg_seq_len = WG_SIZE * q_step; + const size_t wg_count = align_to(q_len, wg_seq_len) / wg_seq_len; + + wgs.global = {batch, heads_num, wg_count * WG_SIZE}; + wgs.local = {1, 1, WG_SIZE}; + + std::vector scaler_value = {q_len}; + scalars.resize(scaler_value.size()); + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast(scaler_value[i]); + } + }}; +} + +JitConstants PagedAttentionGeneratorSingleToken::get_jit_constants(const kernel_impl_params& params) const { + auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); + jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); + auto desc = params.typed_desc(); + const float scale_factor = 1.0 / std::sqrt(static_cast(desc->k_head_size)); + const size_t kv_partition_size = get_partition_size(); + auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; + + jit.make("KV_PARTITION_SIZE", kv_partition_size); + jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + jit.add(make_jit_constant("SCALE_FACTOR", scale_factor)); + jit.make("HEAD_SIZE", desc->k_head_size); + jit.make("HEADS_NUM", desc->heads_num); + jit.make("KV_HEADS_NUM", desc->kv_heads_num); + jit.make("Q_STEP", get_q_step(xe_arch, true)); + + return jit; +} + +Arguments PagedAttentionGeneratorSingleToken::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + const auto desc = params.typed_desc(); + // const auto has_scale_input = !desc->scale_val.has_value(); + const auto has_scores_output = params.output_layouts.size() > 1; + + OPENVINO_ASSERT(!has_scores_output, "[GPU][CM] PagedAttentionGeneratorSingleToken with scores output is not supported yet"); + + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::QUERY}); // queries + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // keys cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE_CACHE}); // values cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::PAST_LENS}); // past lens + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block indices + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence begins + + // outputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); // partition output + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); // lse output + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len==1 + // args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // kv_partition_num + + return args; +} + +DispatchDataFunc PagedAttentionGeneratorSingleToken::get_dispatch_data_func() const { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + const auto desc = params.typed_desc(); + auto rtp = static_cast(rt_params); + + assert(rt_params != nullptr); + + const size_t batch = params.input_layouts[0].get_partial_shape()[0].get_length(); + const size_t heads_num = desc->heads_num; + const size_t partition_num = rtp->num_of_partitions; // get_partition_num(rtp->max_context_len); + wgs.global = {batch, heads_num, partition_num}; + wgs.local = {1, 1, 1}; + + // generate stage: q_len=1 + auto& scalars = kd.params.scalars; + std::vector scaler_value = {1}; + scalars.resize(scaler_value.size()); + + // std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " + // << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", kv_len: " << kv_len << std::endl; + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast(scaler_value[i]); + } + }}; +} + +JitConstants PagedAttentionGeneratorSingleTokenFinalization::get_jit_constants(const kernel_impl_params& params) const { + auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); + const auto desc = params.typed_desc(); + + jit.make("REDUCE_SPLIT_SIZE", reduce_split_step); + jit.make("HEAD_SIZE", desc->k_head_size); + jit.make("HEADS_NUM", desc->heads_num); + return jit; +} + +Arguments PagedAttentionGeneratorSingleTokenFinalization::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + const auto has_scores_output = params.output_layouts.size() > 1; + OPENVINO_ASSERT(!has_scores_output, "[GPU][CM] PagedAttentionGeneratorSingleTokenFinalization with scores output is not supported yet"); + + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); // partition data + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // output + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); // lse + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // kv_partition_num + + return args; +} + +DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_data_func() const { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + + const auto desc = params.typed_desc(); + auto rtp = static_cast(rt_params); + + assert(rt_params != nullptr); + + const size_t batch = params.input_layouts[0].get_partial_shape()[0].get_length(); + const size_t heads_num = desc->heads_num; + const size_t head_size = desc->k_head_size; + wgs.global = {batch, heads_num, head_size / reduce_split_step}; + wgs.local = {1, 1, 1}; + + auto& scalars = kd.params.scalars; + const size_t partition_num = rtp->num_of_partitions; // get_partition_num(rtp->max_context_len); + std::vector scaler_value = {partition_num}; + scalars.resize(scaler_value.size()); + + // std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " + // << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", kv_len: " << kv_len << std::endl; + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast(scaler_value[i]); + } + }}; +} + +JitConstants PagedAttentionGeneratorKVCacheUpdate::get_jit_constants(const kernel_impl_params& params) const { + auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); + + const auto desc = params.typed_desc(); + jit.make("KV_HEADS_NUM", desc->kv_heads_num); + jit.make("K_HEAD_SIZE", desc->k_head_size); + jit.make("V_HEAD_SIZE", desc->v_head_size); + jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size); + jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size); + jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + + return jit; +} + +Arguments PagedAttentionGeneratorKVCacheUpdate::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY}); // queries + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE}); // keys cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::PAST_LENS}); // values cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block indices + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // queries + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE_CACHE}); // keys cache + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // key_pitch + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // value_pitch + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // batch_size_in_sequences + return args; +} + +DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() const { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + const auto desc = params.typed_desc(); + // auto rtp = static_cast(rt_params); + + const size_t kv_len = get_max_context_len(params); + const size_t kv_heads_num = desc->kv_heads_num; + const size_t wg_count = (kv_len + WG_SIZE - 1) / WG_SIZE; + + wgs.global = {1, kv_heads_num, wg_count * WG_SIZE}; + wgs.local = {1, 1, WG_SIZE}; + + auto& scalars = kd.params.scalars; + size_t key_pitch = desc->k_head_size * kv_heads_num; + size_t value_pitch = desc->v_head_size * kv_heads_num; + size_t batch_size_in_sequences = kv_len; + std::vector scaler_value = {key_pitch, value_pitch, batch_size_in_sequences}; + scalars.resize(scaler_value.size()); + + // std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " + // << "batch: " << batch << ", heads_num: " << heads_num << ", split_num: " << split_num << ", kv_len: " << kv_len << std::endl; + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast(scaler_value[i]); + } + }}; +} + +} // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp new file mode 100644 index 00000000000000..a3149054ec6617 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -0,0 +1,96 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "../ocl_v2/utils/jitter.hpp" +#include "common_utils/jitter.hpp" +#include "intel_gpu/graph/kernel_impl_params.hpp" +#include "intel_gpu/primitives/paged_attention.hpp" +#include "intel_gpu/runtime/layout.hpp" +#include "openvino/core/type.hpp" +#include "program_node.h" +#include "registry/implementation_manager.hpp" +#include "utils/kernel_generator.hpp" + +using namespace cldnn; // TODO: Remove once namespaces are aligned + +namespace ov::intel_gpu::cm { + +// constexpr auto get_pa_build_options() { +// return " -cmc -Qxcm_register_file_size=256 -mdump_asm -g2 "; +// } +constexpr auto get_pa_build_options() { + return " -cmc -Qxcm_register_file_size=256"; +} + +// BLOCK_SIZE can be 16/32/64/128/256 +#define PA_KV_CACHE_BLOCK_SIZE 16 +// sparse attention block size is set to 1 to disable sparse attention support in CM kernels +#define PA_SPARSE_BLOCK_SIZE 1 + + +enum class PagedAttentionStage : uint8_t { GENERATE = 0, PREFILL = 1, MIXED = 2, UNKNOWN = 3 }; +struct PagedAttentionRuntimeParams : public ImplRuntimeParams { + PagedAttentionStage stage; + size_t num_of_partitions; + size_t partition_size; + size_t max_context_len; + size_t paged_attention_aligned_seq_len; +}; + +int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, int64_t target_seq_len_block_size); +PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param); +size_t get_max_context_len(const kernel_impl_params& params); +size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx); +size_t get_partition_size(); +size_t get_partition_num(const size_t kv_len); + +class PagedAttentionGeneratorBase : public KernelGenerator { +public: + explicit PagedAttentionGeneratorBase(std::string_view kernel_name, std::string_view stage_suffix = "_cm") : KernelGenerator(kernel_name, stage_suffix) {} + [[nodiscard]] std::string get_build_options(const RuntimeParams& params) const override { + return KernelGenerator::get_build_options(params) + get_pa_build_options(); + } + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; +}; + +class PagedAttentionGeneratorKVCacheUpdate : public PagedAttentionGeneratorBase { +public: + PagedAttentionGeneratorKVCacheUpdate() : PagedAttentionGeneratorBase("pa_kv_cache_update_ref") {} + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +class PagedAttentionGeneratorMultiToken : public PagedAttentionGeneratorBase { +public: + PagedAttentionGeneratorMultiToken() : PagedAttentionGeneratorBase("pa_multi_token") {} + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +class PagedAttentionGeneratorSingleToken : public PagedAttentionGeneratorBase { +public: + PagedAttentionGeneratorSingleToken() : PagedAttentionGeneratorBase("pa_single_token") {} + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +class PagedAttentionGeneratorSingleTokenFinalization : public PagedAttentionGeneratorBase { +public: + PagedAttentionGeneratorSingleTokenFinalization() : PagedAttentionGeneratorBase("pa_single_token_finalization") {} + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +} // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp index 8fccefbc9e5eae..fb5cf4631bace7 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp @@ -25,9 +25,13 @@ struct PagedAttentionOpt : public ImplementationManager { ov::element::f16, }; static constexpr std::array supported_kv_types = { + #if ENABLE_PA_CM_PATH + ov::element::i8, + #else ov::element::f32, ov::element::f16, ov::element::i8, + #endif }; const auto& q_layout = node.get_input_layout(0); diff --git a/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp b/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp index c1fc982350054e..738cf4c9e59d95 100644 --- a/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp +++ b/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp @@ -8,6 +8,7 @@ #if OV_GPU_WITH_OCL #include "impls/ocl_v2/sdpa/paged_attention_opt.hpp" + #include "impls/cm/paged_attention.hpp" #endif namespace ov { @@ -18,6 +19,7 @@ using namespace cldnn; const std::vector>& Registry::get_implementations() { static const std::vector> impls = { OV_GPU_CREATE_INSTANCE_OCL(ocl::PagedAttentionOpt, shape_types::any) + OV_GPU_CREATE_INSTANCE_OCL(cm::PagedAttentionImplementationManager, shape_types::any) }; return impls; diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 3675d481b935de..f8339cb9e46e51 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1181,6 +1181,30 @@ const auto DISABLE_FA_V2 = true; INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ +#if ENABLE_PA_CM_PATH + /* without scores output, dynamic input query paddings */ + paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + + paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token +#else /* with scores output, use SnapKV */ paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token paged_attention_test_params{ {{36, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token @@ -1289,4 +1313,5 @@ INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing: paged_attention_test_params{ {{5, 10}}, 2, 64, 64, 16, 2, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{5, 10}}, 2, 64, 64, 16, 2, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 34}, {2, 20}, {10, 34}}, 2, 64, 64, 16, 10, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token +#endif })); From 435a7ac0a1a9e5120468274d0faba03d09ae04db Mon Sep 17 00:00:00 2001 From: "river.li" Date: Sun, 31 Aug 2025 11:50:05 +0800 Subject: [PATCH 02/96] enabled simple pa unit tests pass --- .../intel_gpu/src/graph/debug_helper.cpp | 27 +++ .../src/graph/impls/cm/pa_single_token.cm | 1 - .../src/graph/impls/cm/paged_attention.cpp | 6 +- .../graph/impls/cm/paged_attention_gen.cpp | 225 ++++++++++++------ .../intel_gpu/src/graph/paged_attention.cpp | 4 +- .../src/plugin/transformations_pipeline.cpp | 2 +- .../tests/common/random_generator.hpp | 10 + .../test_cases/paged_attention_gpu_test.cpp | 99 ++++++-- 8 files changed, 278 insertions(+), 96 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/debug_helper.cpp b/src/plugins/intel_gpu/src/graph/debug_helper.cpp index a32a44b3fa1f45..068570dec5e579 100644 --- a/src/plugins/intel_gpu/src/graph/debug_helper.cpp +++ b/src/plugins/intel_gpu/src/graph/debug_helper.cpp @@ -491,6 +491,33 @@ NodeDebugHelper::~NodeDebugHelper() { log_memory_to_file(output_mem, output_layout, m_stream, filename, dump_raw); } } + + for (size_t i = 0; i < m_inst.inputs_memory_count(); i++) { + std::string name = get_file_prefix() + "_updated_src_" + std::to_string(i); + auto output_mem = m_inst.input_memory_ptr(i); + if (output_mem == nullptr) { + GPU_DEBUG_COUT << " updated_input_mem is nullptr. Nothing to dump." << std::endl; + continue; + } + + auto& output_layout = m_inst.get_input_layout(i); + if (config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::binary) { + // Binary dump : raw + auto filename = get_file_path_for_binary_dump(output_layout, name, config.get_dump_tensors_path()); + + mem_lock lock(output_mem, m_stream); + ov::util::save_binary(filename, lock.data(), output_mem->size()); + GPU_DEBUG_COUT << " Dump layer dst : " << layer_name << " to " << filename << std::endl; + debug_str_for_bin_load += (filename + ","); + } else { + const bool dump_raw = config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::text_raw; + GPU_DEBUG_COUT << " Dump " << (dump_raw ? "raw " : "") << name << std::endl; + auto filename = config.get_dump_tensors_path() + get_name_for_dump(name) + ".txt"; + // Text dump + log_memory_to_file(output_mem, output_layout, m_stream, filename, dump_raw); + } + } + if (config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::binary && m_inst.is_input()) { debug_str_for_bin_load[debug_str_for_bin_load.size()-1] = '\"'; GPU_DEBUG_COUT << debug_str_for_bin_load << std::endl;; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index 8adea7ad4f9a22..d1636173c9b4a9 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -98,7 +98,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( return; } const uint total_blocks_num = (kv_len + KV_BLOCK_SIZE - 1) / KV_BLOCK_SIZE; - //#TODO: int8 compression data uint kv_pitch = HEAD_SIZE * sizeof(half); //# fp16 data diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index fc3c18e39f3dfe..12edf02ffb4a8a 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -31,7 +31,9 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { Stage::Ptr pa_single_token_finalization = make_stage(); Stage::Ptr pa_multi_token = make_stage(); - PagedAttentionCmImpl(): PrimitiveImplCM(PagedAttentionImplementationManager::get_type_info_static()) {} + PagedAttentionCmImpl(): PrimitiveImplCM(PagedAttentionImplementationManager::get_type_info_static()) { + m_rt_params = std::make_unique(); + } explicit PagedAttentionCmImpl(const kernel_impl_params& params) : PagedAttentionCmImpl() { const auto desc = params.typed_desc(); @@ -45,7 +47,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { void update_rt_params(const primitive_inst& instance) override { update_stages_flags(instance); if (m_rt_params == nullptr) { - m_rt_params = std::make_unique(); + m_rt_params = std::make_unique(); } std::cout << "ov::intel_gpu::cm::PagedAttentionCmImpl::update_rt_params()" << std::endl; const auto& params = *instance.get_impl_params(); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 8b271270b5f9c4..2f910029e11488 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -42,7 +42,8 @@ inline size_t get_q_step(size_t arch, bool is_single_token = false) { if (arch == 1) { return is_single_token ? 1 : 8; // For Xe1 } else if (arch == 2) { - return is_single_token ? 1 : 32; // For Xe2 + // For Xe2, q_step = CM_GRF_WIDTH / 32 + return is_single_token ? 1 : 16; // For Xe2 } OPENVINO_ASSERT(false, "Unsupported architecture for Q step"); return 0; // Fallback case, should not be reached @@ -206,6 +207,9 @@ PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_par return PagedAttentionStage::UNKNOWN; } +//----------------------------------------------------------------------------------------------------------------- +// Base generator +//----------------------------------------------------------------------------------------------------------------- JitConstants PagedAttentionGeneratorBase::get_jit_constants(const kernel_impl_params& params) const { auto jit = KernelGenerator::get_jit_constants(params); jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); @@ -223,20 +227,136 @@ JitConstants PagedAttentionGeneratorBase::get_jit_constants(const kernel_impl_pa return jit; } +JitConstants PagedAttentionGeneratorKVCacheUpdate::get_jit_constants(const kernel_impl_params& params) const { + auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); + + const auto desc = params.typed_desc(); + jit.make("KV_HEADS_NUM", desc->kv_heads_num); + jit.make("K_HEAD_SIZE", desc->k_head_size); + jit.make("V_HEAD_SIZE", desc->v_head_size); + jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size); + jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size); + jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + + return jit; +} + +//----------------------------------------------------------------------------------------------------------------- +// KV cache update generator +//----------------------------------------------------------------------------------------------------------------- +Arguments PagedAttentionGeneratorKVCacheUpdate::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY}); // queries + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE}); // keys cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::PAST_LENS}); // values cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block indices + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // queries + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE_CACHE}); // keys cache + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // key_pitch + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // value_pitch + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // batch_size_in_sequences + return args; +} + +DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() const { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + const auto desc = params.typed_desc(); + // auto rtp = static_cast(rt_params); + + const size_t kv_len = get_max_context_len(params); + const size_t kv_heads_num = desc->kv_heads_num; + const size_t wg_count = (kv_len + WG_SIZE - 1) / WG_SIZE; + + wgs.global = {1, kv_heads_num, wg_count * WG_SIZE}; + wgs.local = {1, 1, WG_SIZE}; + + auto& scalars = kd.params.scalars; + // TODO: how to get pitch for dynamic_padding? + size_t key_pitch = desc->k_head_size * kv_heads_num; + size_t value_pitch = desc->v_head_size * kv_heads_num; + auto key_layout = params.input_layouts[PagedAttentionInputIdx::KEY]; + auto value_layout = params.input_layouts[PagedAttentionInputIdx::VALUE]; + + { + std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " + << "key_layout: " << key_layout.to_string() << ", value_layout: " << value_layout.to_string() << std::endl; + std::cout << "\tkey_dims = ["; + for (auto& it : key_layout.get_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\tkey_pads = ["; + for (auto& it : key_layout.get_padded_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\tvalue_dims = ["; + for (auto& it : value_layout.get_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\tvalue_pads = ["; + for (auto& it : value_layout.get_padded_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + } + + auto get_simple_pitch = [](const layout& layout) { + size_t pitch = 1; + auto dims_padding = layout.get_padded_dims(); + for(size_t i = dims_padding.size() - 1; i > 0; --i) { + pitch = dims_padding[i]; + if(pitch > 1) { + break; + } + } + return pitch; + }; + key_pitch = get_simple_pitch(key_layout); + value_pitch = get_simple_pitch(value_layout); + std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " + << "key_pitch: " << key_pitch << ", value_pitch: " << value_pitch << std::endl; + + // TODO: support multiple sequences + size_t batch_size_in_sequences = 1; + std::vector scaler_value = {key_pitch, value_pitch, batch_size_in_sequences}; + scalars.resize(scaler_value.size()); + + // std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " + // << "batch: " << batch << ", heads_num: " << heads_num << ", split_num: " << split_num << ", kv_len: " << kv_len << std::endl; + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast(scaler_value[i]); + } + }}; +} + +//----------------------------------------------------------------------------------------------------------------- +// multi token generator +//----------------------------------------------------------------------------------------------------------------- Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_impl_params& params) const { const auto desc = params.typed_desc(); Arguments args; + // Doesn't support Query with dynamic_padding args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::QUERY}); // query - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY}); // key - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE}); // value + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // key_cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE_CACHE}); // value_cache args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::PAST_LENS}); // past_lens args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block_indices args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block_indices_begins args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence_begins -# if PA_SPARSE_BLOCK_SIZE > 1 +#if PA_SPARSE_BLOCK_SIZE > 1 args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SPARSE_BLOCK_MASK}); // sparse_block_mask -# endif +#endif args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len @@ -273,6 +393,22 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con // assert(rt_params != nullptr); const size_t heads_num = desc->heads_num; + auto query_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; + + { + std::cout << "PagedAttentionGeneratorMultiToken::get_dispatch_data_func: query_layout: " << query_layout.to_string() << std::endl; + std::cout << "\tquery_dims = ["; + for (auto& it : query_layout.get_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\tquery_pads = ["; + for (auto& it : query_layout.get_padded_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + } + auto out_shape = params.output_layouts[0].get_shape(); const size_t batch = out_shape.size() < 4 ? 1 : out_shape[0]; const size_t q_len = out_shape[0]; @@ -285,6 +421,15 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con wgs.global = {batch, heads_num, wg_count * WG_SIZE}; wgs.local = {1, 1, WG_SIZE}; + { + std::cout << "PagedAttentionGeneratorMultiToken::get_dispatch_data_func: \n" + << "\tbatch: " << batch << ", heads_num: " << heads_num << ", q_len: " << q_len + << ", q_step: " << q_step << ", wg_seq_len: " << wg_seq_len << ", wg_count: " << wg_count + << ", global_work_size: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" + << ", local_work_size: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" + << std::endl; + } + std::vector scaler_value = {q_len}; scalars.resize(scaler_value.size()); for (size_t i = 0; i < scaler_value.size(); ++i) { @@ -294,6 +439,9 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con }}; } +//----------------------------------------------------------------------------------------------------------------- +// single token generator +//----------------------------------------------------------------------------------------------------------------- JitConstants PagedAttentionGeneratorSingleToken::get_jit_constants(const kernel_impl_params& params) const { auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); @@ -318,7 +466,6 @@ Arguments PagedAttentionGeneratorSingleToken::get_arguments_desc(const kernel_im const auto desc = params.typed_desc(); // const auto has_scale_input = !desc->scale_val.has_value(); const auto has_scores_output = params.output_layouts.size() > 1; - OPENVINO_ASSERT(!has_scores_output, "[GPU][CM] PagedAttentionGeneratorSingleToken with scores output is not supported yet"); args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::QUERY}); // queries @@ -370,6 +517,9 @@ DispatchDataFunc PagedAttentionGeneratorSingleToken::get_dispatch_data_func() co }}; } +//----------------------------------------------------------------------------------------------------------------- +// single token finalization generator +//----------------------------------------------------------------------------------------------------------------- JitConstants PagedAttentionGeneratorSingleTokenFinalization::get_jit_constants(const kernel_impl_params& params) const { auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); const auto desc = params.typed_desc(); @@ -426,67 +576,4 @@ DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_da }}; } -JitConstants PagedAttentionGeneratorKVCacheUpdate::get_jit_constants(const kernel_impl_params& params) const { - auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); - - const auto desc = params.typed_desc(); - jit.make("KV_HEADS_NUM", desc->kv_heads_num); - jit.make("K_HEAD_SIZE", desc->k_head_size); - jit.make("V_HEAD_SIZE", desc->v_head_size); - jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size); - jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size); - jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); - - return jit; -} - -Arguments PagedAttentionGeneratorKVCacheUpdate::get_arguments_desc(const kernel_impl_params& params) const { - Arguments args; - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY}); // queries - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE}); // keys cache - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::PAST_LENS}); // values cache - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block indices - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence begins - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // queries - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE_CACHE}); // keys cache - - // scalar - args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // key_pitch - args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // value_pitch - args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // batch_size_in_sequences - return args; -} - -DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() const { - return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { - assert(!params.is_dynamic()); - auto& wgs = kd.params.workGroups; - const auto desc = params.typed_desc(); - // auto rtp = static_cast(rt_params); - - const size_t kv_len = get_max_context_len(params); - const size_t kv_heads_num = desc->kv_heads_num; - const size_t wg_count = (kv_len + WG_SIZE - 1) / WG_SIZE; - - wgs.global = {1, kv_heads_num, wg_count * WG_SIZE}; - wgs.local = {1, 1, WG_SIZE}; - - auto& scalars = kd.params.scalars; - size_t key_pitch = desc->k_head_size * kv_heads_num; - size_t value_pitch = desc->v_head_size * kv_heads_num; - size_t batch_size_in_sequences = kv_len; - std::vector scaler_value = {key_pitch, value_pitch, batch_size_in_sequences}; - scalars.resize(scaler_value.size()); - - // std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " - // << "batch: " << batch << ", heads_num: " << heads_num << ", split_num: " << split_num << ", kv_len: " << kv_len << std::endl; - - for (size_t i = 0; i < scaler_value.size(); ++i) { - scalars[i].t = ScalarDescriptor::Types::INT32; - scalars[i].v.s32 = static_cast(scaler_value[i]); - } - }}; -} - } // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index e4d566b2a87664..2ec8e127b26cc0 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -47,9 +47,9 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no "[GPU] Paged Attention key cache quantization mode mismatch: prim.is_key_by_channel : ", desc->is_key_by_channel, " but exec_config : ", impl_param.get_program().get_config().get_key_cache_quant_mode()); bool valid_block_size = key_cache_ps.is_dynamic() || - (key_cache_ps[key_cache_idx].get_length() == static_cast(expected_block_size)); + (key_cache_ps[key_cache_idx-1].get_length() == static_cast(expected_block_size)); OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation for key cache quant mode " - , key_cache_quant_mode, ". Expected ", expected_block_size, ", but got ", key_cache_ps[key_cache_idx].get_length()); + , key_cache_quant_mode, ". Expected ", expected_block_size, ", but got ", key_cache_ps[key_cache_idx-1].get_length()); std::vector output_layouts{ data_layout }; if (desc->has_scores_output()) { diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index d96bbbe433f709..e51f9c1b180b74 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -513,7 +513,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); kv_cache_config.inferencePrecision = infer_precision; kv_cache_config.keyCacheBlockSize = 16; - kv_cache_config.keyCacheDimOrder = {0, 1, 3, 2}; + kv_cache_config.keyCacheDimOrder = {0, 1, 2, 3}; kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; kv_cache_config.valueCacheBlockSize = 16; diff --git a/src/plugins/intel_gpu/tests/common/random_generator.hpp b/src/plugins/intel_gpu/tests/common/random_generator.hpp index 8dfb4a616b1c6d..92f4b32591cb6d 100644 --- a/src/plugins/intel_gpu/tests/common/random_generator.hpp +++ b/src/plugins/intel_gpu/tests/common/random_generator.hpp @@ -57,6 +57,16 @@ class random_generator { return v; } + template + std::vector generate_random_1d_fixed(size_t a, int start, int step, int k = 100) { + std::vector v(a); + + for (size_t i = 0; i < a; ++i) { + v[i] = static_cast(start + i * step) / k; + } + return v; + } + template std::vector> generate_random_2d(size_t a, size_t b, int min, int max, int k = 8) { std::vector> v(a); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index f8339cb9e46e51..ee1d88e351cf21 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -218,6 +218,66 @@ struct PagedAttentionManager { return get_QKV_memory(value_data, v_head_size, true); } +#if ENABLE_PA_CM_PATH + memory::ptr get_key_cache_memory() { + auto key_cache_dt = data_types::f16; + auto adjusted_head_size = k_head_size; + if (kv_cache_compression) { + key_cache_dt = data_types::i8; + adjusted_head_size += 4; + } + + auto num_blocks = block_indices.back() + 1; + auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; + auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; + auto memory = test_engine.allocate_memory(key_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + + input_token_offset * num_heads * v_head_size + + head_idx * v_head_size; + if (kv_cache_compression) { + auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); + auto quantized_data_ptr = quantized_data.data(); + + // shape: [num_blocks, num_heads, block_size, adjusted_head_size] + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + + token_idx * v_head_size; + set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); + + size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } else { + // shape: [num_blocks, num_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + + head_idx * block_size * v_head_size + + token_idx * v_head_size; + + set_values(test_stream, memory, data_ptr, v_head_size, output_offset); + } + } + } + } + } + } + + return memory; + } + +#else memory::ptr get_key_cache_memory() { auto key_cache_dt = data_types::f16; auto adjusted_head_size = k_head_size; @@ -315,6 +375,7 @@ struct PagedAttentionManager { return memory; } +#endif memory::ptr get_value_cache_memory() { auto value_cache_dt = data_types::f16; @@ -530,6 +591,9 @@ struct PagedAttentionManager { const size_t total_elements_num = tokens_num * num_heads * k_head_size; auto data = rg.generate_random_1d(total_elements_num, -1, 1); + // test code + //auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 500); + return data; } @@ -966,7 +1030,11 @@ struct PagedAttentionTest : public ::testing::TestWithParam { query_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.k_head_size }); key_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.k_head_size }); value_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.v_head_size }); +#if ENABLE_PA_CM_PATH + key_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.k_head_size }); +#else key_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.k_head_size, p.block_size }); +#endif value_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.v_head_size }); past_lens_layout.set_partial_shape(ov::PartialShape{ -1 }); subsequence_begins_layout.set_partial_shape(ov::PartialShape{ -1 }); @@ -1129,6 +1197,9 @@ struct PagedAttentionTest : public ::testing::TestWithParam { if (data_output_mem) { ASSERT_EQ(data_output_mem->count(), ref_data.first.size()); mem_lock mem_ptr(data_output_mem, get_test_stream()); + // for (size_t i = 0; i < data_output_mem->count(); i++) { + // std::cout << i << ": result = " << mem_ptr[i] << ", reference = " << ref_data.first[i] << std::endl; + // } for (size_t i = 0; i < data_output_mem->count(); i++) { ASSERT_NEAR(mem_ptr[i], ref_data.first[i], tolerance) << " at index=" << i; } @@ -1182,28 +1253,14 @@ const auto DISABLE_FA_V2 = true; INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ #if ENABLE_PA_CM_PATH - /* without scores output, dynamic input query paddings */ - paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - - paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, token_size>=32, disable_mix_mode */ + paged_attention_test_params{ {{32, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token paged_attention_test_params{ {{1024, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + + paged_attention_test_params{ {{1, 31}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 1023}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 127}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 128}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token #else /* with scores output, use SnapKV */ paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token From 8947906bb3a2a01ebba2c46595156f64e175d0c6 Mon Sep 17 00:00:00 2001 From: "river.li" Date: Mon, 1 Sep 2025 00:36:10 +0800 Subject: [PATCH 03/96] Fix 2nd_token issue --- .../src/graph/impls/cm/pa_single_token.cm | 53 ++++++++++---- .../src/graph/impls/cm/paged_attention.cpp | 18 ++--- .../src/graph/impls/cm/paged_attention.hpp | 15 ++-- .../graph/impls/cm/paged_attention_gen.cpp | 69 +++++++------------ .../test_cases/paged_attention_gpu_test.cpp | 12 ++-- 5 files changed, 90 insertions(+), 77 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index d1636173c9b4a9..6c795bce53d966 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -48,6 +48,19 @@ #define KV_SCALE_ZP_SIZE 0 // 4: scale/zp size +template +void show(matrix mat) { + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[m][n]); + } + printf("],\n"); + } + printf("]\n"); +} + + extern "C" _GENX_MAIN_ void KERNEL_NAME( // extern "C" _GENX_MAIN_ void cm_sdpa_2nd( half* query [[type("svmptr_t")]], @@ -95,9 +108,11 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const uint start_block_idx = block_indices_begins[subsequence_idx] + partition_idx * (KV_PARTITION_SIZE / KV_BLOCK_SIZE); if(partition_idx * KV_PARTITION_SIZE > kv_len) { + // printf("WG exit: partition_idx=%d, KV_PARTITION_SIZE=%d, kv_len=%d\n", partition_idx, KV_PARTITION_SIZE, kv_len); return; } const uint total_blocks_num = (kv_len + KV_BLOCK_SIZE - 1) / KV_BLOCK_SIZE; + //#TODO: int8 compression data uint kv_pitch = HEAD_SIZE * sizeof(half); //# fp16 data @@ -118,11 +133,14 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const uint per_kv_block_element_num = KV_BLOCK_SIZE * KV_HEADS_NUM * (HEAD_SIZE + KV_SCALE_ZP_SIZE / sizeof(half)); // 4: scale/zp uint block_num = KV_PARTITION_SIZE / KV_BLOCK_SIZE; + uint leftover_aligned_size = 0; uint leftover_size = 0; + if(partition_idx == kv_partition_num - 1) { + leftover_size = (kv_len - KV_PARTITION_SIZE * partition_idx) % KV_PARTITION_SIZE; + leftover_aligned_size = KV_STEP * ((leftover_size + KV_STEP - 1) / KV_STEP); // round up to KV_STEP + } if(block_num > total_blocks_num - start_block_idx) { block_num = total_blocks_num - start_block_idx; - leftover_size = kv_len - KV_PARTITION_SIZE * partition_idx; - leftover_size = KV_STEP * ((leftover_size + KV_STEP - 1) / KV_STEP); // round up to KV_STEP } //# rS = Q @ Kt @@ -153,8 +171,8 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #endif uint kv_pos_end = KV_BLOCK_SIZE; - if(block_idx == block_num - 1 && leftover_size > 0) { - kv_pos_end = leftover_size % KV_BLOCK_SIZE; + if(block_idx == block_num - 1 && start_block_idx > 0) { + kv_pos_end = start_block_idx % KV_BLOCK_SIZE; } for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += KV_STEP, ki++) { auto rSvec = rS[ki].format(); @@ -190,8 +208,10 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } } - // printf("rS:\n"); - // show(rS); + //if(wg_thread_id==1) { + // printf("rS:\n"); + // show(rS); + //} // online softmax float cur_sum = 0.0f; @@ -210,9 +230,18 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( rS = cm_mul(rS, (float)SCALE_FACTOR); // convert scale_factor into (float), or it will be promoted to double //rS = cm_add(rS, MaskMat); + // printf("leftover_size = %d, leftover_aligned_size = %d, XE_ARCH = %d, PARTITION_SUBBLOCK_NUM * REG_N = %d\n", leftover_size, leftover_aligned_size, XE_ARCH, PARTITION_SUBBLOCK_NUM * REG_N); + if(leftover_size > 0) { + auto Svec = rS.format(); + for(int i = leftover_size; i < PARTITION_SUBBLOCK_NUM * REG_N; i++){ + Svec[i] = -10000000000.0f; + } + } + // compute lse constexpr float log2e = 1.4426950408889634f; vector rS_exp = cm_exp(rS.format()*log2e); + cur_lse += cm_sum(rS_exp); // compute row_max @@ -233,7 +262,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( cur_sum = cm_sum(rPv[0]); } - //if(wg_thread_id==0) { + // if(wg_thread_id==kv_partition_num - 1 && head_num_idx == 0) { // printf("Pmat:\n"); // show(Pmat); //} @@ -258,8 +287,8 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #endif uint kv_pos_end = KV_BLOCK_SIZE; - if(block_idx == block_num - 1 && leftover_size > 0) { - kv_pos_end = leftover_size % KV_BLOCK_SIZE; + if(block_idx == block_num - 1 && start_block_idx > 0) { + kv_pos_end = start_block_idx % KV_BLOCK_SIZE; } for(int kv_pos =0; kv_pos < kv_pos_end; kv_pos += REG_K, ki++) { uint kv_offset_y = kv_pos; @@ -288,11 +317,11 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } } } - - //if(wg_thread_id==0) { + + // if(wg_thread_id==kv_partition_num - 1) { // printf("Omat:\n"); // show(Omat); - //} + // } //# save Output matrix cur_O_f16; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 12edf02ffb4a8a..bb00c86cbc1e71 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -37,7 +37,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { explicit PagedAttentionCmImpl(const kernel_impl_params& params) : PagedAttentionCmImpl() { const auto desc = params.typed_desc(); - std::cout << "ov::intel_gpu::cm::PagedAttentionCmImpl::PagedAttentionCmImpl()" << std::endl; + GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::PagedAttentionCmImpl()" << std::endl; add_stage(kv_cache_update, params); add_stage(pa_single_token, params); add_stage(pa_single_token_finalization, params); @@ -49,7 +49,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { if (m_rt_params == nullptr) { m_rt_params = std::make_unique(); } - std::cout << "ov::intel_gpu::cm::PagedAttentionCmImpl::update_rt_params()" << std::endl; + GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::update_rt_params()" << std::endl; const auto& params = *instance.get_impl_params(); auto rt_params = static_cast(m_rt_params.get()); const auto& desc = params.typed_desc(); @@ -60,8 +60,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { rt_params->num_of_partitions = ceil_div(max_context_len, rt_params->partition_size); rt_params->stage = get_paged_attention_stage(params); - std::cout << " max_context_len: " << rt_params->max_context_len << " partition_size: " << rt_params->partition_size - << " num_of_partitions: " << rt_params->num_of_partitions << ", stage: " << static_cast(rt_params->stage) << std::endl; + GPU_DEBUG_TRACE_DETAIL << " max_context_len: " << rt_params->max_context_len << " partition_size: " << rt_params->partition_size + << " num_of_partitions: " << rt_params->num_of_partitions << ", stage: " << static_cast(rt_params->stage) << std::endl; } // update impl_parameter and rt_parameter @@ -78,7 +78,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto rt_params = static_cast(m_rt_params.get()); assert(rt_params != nullptr); - std::cout << "ov::intel_gpu::cm::PagedAttentionCmImpl::execute(): stage = " << static_cast(rt_params->stage) << std::endl; + GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::execute(): stage = " << static_cast(rt_params->stage) << std::endl; std::vector res_event = events; res_event = {execute_stage(res_event, instance, kv_cache_update)}; @@ -121,8 +121,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { partition_size = get_partition_size(); num_of_partitions = ceil_div(max_context_len, partition_size); } - std::cout << "ov::intel_gpu::cm::PagedAttentionCmImpl::get_internal_buffer_descs(): stage = " << static_cast(stage) - << " partition_size: " << partition_size << " num_of_partitions: " << num_of_partitions << std::endl; + GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::get_internal_buffer_descs(): stage = " << static_cast(stage) + << " partition_size: " << partition_size << " num_of_partitions: " << num_of_partitions << std::endl; if (stage == PagedAttentionStage::GENERATE) { const auto& input = params.input_layouts[0]; const int64_t total_tokens = input.get_partial_shape()[0].get_length(); @@ -132,8 +132,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { internal_buffers.emplace_back(tmp_out_elements_count * element_size, indexes_dt); // 0: intermediate partition output internal_buffers.emplace_back(buf_elements_count * element_size, indexes_dt); // 1: softmax exp_sums - std::cout << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * element_size << " exp_sums=" << buf_elements_count * element_size - << std::endl; + GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * element_size + << " exp_sums=" << buf_elements_count * element_size << std::endl; } else { internal_buffers.emplace_back(16, indexes_dt); // 0: intermediate partition output internal_buffers.emplace_back(16, indexes_dt); // 1: softmax exp_sums diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp index e5909242813758..7b28c001879f23 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp @@ -6,10 +6,11 @@ #include #include -#include "program_node.h" + #include "intel_gpu/runtime/layout.hpp" -#include "registry/implementation_manager.hpp" #include "paged_attention_inst.h" +#include "program_node.h" +#include "registry/implementation_manager.hpp" using namespace cldnn; // TODO: Remove once namespaces are aligned @@ -34,7 +35,7 @@ struct PagedAttentionImplementationManager : public ImplementationManager { const auto& info = engine.get_device_info(); // CM optimized for systolic-array architectures if (!check_cm_jit_support(engine, config) || !info.supports_immad || !config.get_use_cm()) { - std::cout << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + GPU_DEBUG_TRACE_DETAIL << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; return false; } @@ -43,21 +44,21 @@ struct PagedAttentionImplementationManager : public ImplementationManager { const auto& v_layout = node.get_input_layout(2); const auto& out_layout = node.get_output_layout(0); if (!everyone_is(format::bfyx, q_layout.format, k_layout.format, v_layout.format, out_layout.format)) { - std::cout << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + GPU_DEBUG_TRACE_DETAIL << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; return false; } if (!one_of(k_layout.data_type, supported_kv_types) || !one_of(v_layout.data_type, supported_kv_types)) { - std::cout << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + GPU_DEBUG_TRACE_DETAIL << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; return false; } if (!one_of(q_layout.data_type, supported_q_types) || !one_of(out_layout.data_type, supported_q_types)) { - std::cout << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + GPU_DEBUG_TRACE_DETAIL << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; return false; } - std::cout << "ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - true" << std::endl; + GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - true" << std::endl; return true; } }; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 2f910029e11488..b2c1af29652bf2 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -26,6 +26,8 @@ constexpr size_t WG_SIZE = 16; constexpr size_t reduce_split_step = 16; } // namespace +#define DEBUG_ENABLED 1 + // This function returns the kv_step and kv_split_len based on the architecture. // return {kv_step, kv_split_len} inline std::pair get_kv_split_size(size_t arch) { @@ -213,7 +215,7 @@ PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_par JitConstants PagedAttentionGeneratorBase::get_jit_constants(const kernel_impl_params& params) const { auto jit = KernelGenerator::get_jit_constants(params); jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); - std::cout << "PagedAttentionGeneratorBase::get_jit_constants: " << get_entry_point(params) << std::endl; + // std::cout << "PagedAttentionGeneratorBase::get_jit_constants: " << get_entry_point(params) << std::endl; // auto desc = params.typed_desc(); auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; @@ -277,37 +279,11 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() wgs.local = {1, 1, WG_SIZE}; auto& scalars = kd.params.scalars; - // TODO: how to get pitch for dynamic_padding? size_t key_pitch = desc->k_head_size * kv_heads_num; size_t value_pitch = desc->v_head_size * kv_heads_num; auto key_layout = params.input_layouts[PagedAttentionInputIdx::KEY]; auto value_layout = params.input_layouts[PagedAttentionInputIdx::VALUE]; - { - std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " - << "key_layout: " << key_layout.to_string() << ", value_layout: " << value_layout.to_string() << std::endl; - std::cout << "\tkey_dims = ["; - for (auto& it : key_layout.get_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tkey_pads = ["; - for (auto& it : key_layout.get_padded_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tvalue_dims = ["; - for (auto& it : value_layout.get_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tvalue_pads = ["; - for (auto& it : value_layout.get_padded_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - } - auto get_simple_pitch = [](const layout& layout) { size_t pitch = 1; auto dims_padding = layout.get_padded_dims(); @@ -321,17 +297,17 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() }; key_pitch = get_simple_pitch(key_layout); value_pitch = get_simple_pitch(value_layout); - std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " - << "key_pitch: " << key_pitch << ", value_pitch: " << value_pitch << std::endl; + + if (DEBUG_ENABLED) { // Debug + std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " + << "key_pitch: " << key_pitch << ", value_pitch: " << value_pitch << std::endl; + } // TODO: support multiple sequences size_t batch_size_in_sequences = 1; std::vector scaler_value = {key_pitch, value_pitch, batch_size_in_sequences}; scalars.resize(scaler_value.size()); - // std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " - // << "batch: " << batch << ", heads_num: " << heads_num << ", split_num: " << split_num << ", kv_len: " << kv_len << std::endl; - for (size_t i = 0; i < scaler_value.size(); ++i) { scalars[i].t = ScalarDescriptor::Types::INT32; scalars[i].v.s32 = static_cast(scaler_value[i]); @@ -395,7 +371,7 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con auto query_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; - { + if (DEBUG_ENABLED) { // Debug std::cout << "PagedAttentionGeneratorMultiToken::get_dispatch_data_func: query_layout: " << query_layout.to_string() << std::endl; std::cout << "\tquery_dims = ["; for (auto& it : query_layout.get_dims()) { @@ -421,13 +397,12 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con wgs.global = {batch, heads_num, wg_count * WG_SIZE}; wgs.local = {1, 1, WG_SIZE}; - { + if (DEBUG_ENABLED) { // Debug std::cout << "PagedAttentionGeneratorMultiToken::get_dispatch_data_func: \n" - << "\tbatch: " << batch << ", heads_num: " << heads_num << ", q_len: " << q_len - << ", q_step: " << q_step << ", wg_seq_len: " << wg_seq_len << ", wg_count: " << wg_count - << ", global_work_size: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" - << ", local_work_size: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" - << std::endl; + << "\tbatch: " << batch << ", heads_num: " << heads_num << ", q_len: " << q_len << ", q_step: " << q_step + << ", wg_seq_len: " << wg_seq_len << ", wg_count: " << wg_count << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " + << wgs.global[2] << "]" + << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } std::vector scaler_value = {q_len}; @@ -507,8 +482,12 @@ DispatchDataFunc PagedAttentionGeneratorSingleToken::get_dispatch_data_func() co std::vector scaler_value = {1}; scalars.resize(scaler_value.size()); - // std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " - // << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", kv_len: " << kv_len << std::endl; + if (DEBUG_ENABLED) { // Debug + std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " + << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", gws: [" << wgs.global[0] << ", " + << wgs.global[1] << ", " << wgs.global[2] << "]" + << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + } for (size_t i = 0; i < scaler_value.size(); ++i) { scalars[i].t = ScalarDescriptor::Types::INT32; @@ -566,8 +545,12 @@ DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_da std::vector scaler_value = {partition_num}; scalars.resize(scaler_value.size()); - // std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " - // << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", kv_len: " << kv_len << std::endl; + if (DEBUG_ENABLED) { // Debug + std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " + << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", gws: [" << wgs.global[0] << ", " + << wgs.global[1] << ", " << wgs.global[2] << "]" + << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + } for (size_t i = 0; i < scaler_value.size(); ++i) { scalars[i].t = ScalarDescriptor::Types::INT32; diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index ee1d88e351cf21..82f6888f668a57 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -235,7 +235,7 @@ struct PagedAttentionManager { for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { int past_len = subsequence_descs[i].past_len; if (past_len != 0) { - int blocks_num = ceil_div(past_len, block_size); + int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size @@ -298,7 +298,7 @@ struct PagedAttentionManager { for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { int past_len = subsequence_descs[i].past_len; if (past_len != 0) { - int blocks_num = ceil_div(past_len, block_size); + int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size @@ -393,7 +393,7 @@ struct PagedAttentionManager { for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { int past_len = subsequence_descs[i].past_len; if (past_len != 0) { - int blocks_num = ceil_div(past_len, block_size); + int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size @@ -592,7 +592,7 @@ struct PagedAttentionManager { auto data = rg.generate_random_1d(total_elements_num, -1, 1); // test code - //auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 500); + // auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 10000); return data; } @@ -952,7 +952,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam { public: random_generator rg; cldnn::engine& engine = get_test_engine(); - float tolerance = 2e-3; + float tolerance = 2e-2; void SetUp() override { rg.set_seed(GET_SUITE_NAME); @@ -1260,7 +1260,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing: paged_attention_test_params{ {{1, 31}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 1023}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 127}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 128}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 129}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token #else /* with scores output, use SnapKV */ paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token From 83dba292f143e8184ad43f3a5567eb9f1a0719bd Mon Sep 17 00:00:00 2001 From: "river.li" Date: Tue, 2 Sep 2025 08:44:34 +0800 Subject: [PATCH 04/96] Fixed pipeline output corruption issue 1. kvcache update's k/v offset issue 2. 2nd token lse data overflow issue --- .../graph/impls/cm/pa_kv_cache_update_ref.cm | 12 +++- .../src/graph/impls/cm/pa_single_token.cm | 47 ++++++++------ .../impls/cm/pa_single_token_finalization.cm | 11 +++- .../src/graph/impls/cm/paged_attention.cpp | 10 ++- .../graph/impls/cm/paged_attention_gen.cpp | 64 ++++++++++++++++--- .../test_cases/paged_attention_gpu_test.cpp | 1 + 6 files changed, 106 insertions(+), 39 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm index 18623b400a31e8..c9198a8c89dfac 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm @@ -35,7 +35,9 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( half* key_cache [[type("svmptr_t")]], half* value_cache [[type("svmptr_t")]], uint32_t key_pitch, + uint32_t key_offset, uint32_t value_pitch, + uint32_t value_offset, uint32_t batch_size_in_sequences) { // # key: [batch_size_in_tokens, num_kv_heads * k_head_size] // # value [batch_size_in_tokens, num_kv_heads * v_head_size] @@ -85,7 +87,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( { uint block_k_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; uint key_out_offset = block_k_base_offset + token_start_pos * K_HEAD_SIZE; - uint key_in_offset = token_idx * key_pitch + head_idx * K_HEAD_SIZE; + uint key_in_offset = token_idx * key_pitch + head_idx * K_HEAD_SIZE + key_offset; vector key_data; key_data.format() = cm_ptr_load((int*)key, key_in_offset * (int)sizeof(half)); @@ -94,7 +96,13 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( { uint block_v_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_V_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; uint value_out_offset = block_v_base_offset + token_start_pos * V_HEAD_SIZE; - uint value_in_offset = token_idx * value_pitch + head_idx * V_HEAD_SIZE; + uint value_in_offset = token_idx * value_pitch + head_idx * V_HEAD_SIZE + value_offset; + + //if(token_idx==0 && head_idx==0) + //{ + // printf("value_pitch = %d, value_in_offset: %d, value_out_offset: %d,V_HEAD_SIZE = %d, ADJUSTED_V_HEAD_SIZE = %d\n", + // value_pitch, value_in_offset, value_out_offset, V_HEAD_SIZE, ADJUSTED_V_HEAD_SIZE); + //} vector value_data; value_data.format() = cm_ptr_load((int*)value, value_in_offset * (int)sizeof(half)); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index 6c795bce53d966..dc7a10d766ace2 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -15,14 +15,12 @@ *******************************************************************************/ // xe-1:8, xe-2:16 -#if xe_arch==1 +#if XE_ARCH==1 #define REG_N 8 #define USE_LSC_BLOCK_2D_DESC 0 -// #define KV_STEP 8 #else #define REG_N 16 #define USE_LSC_BLOCK_2D_DESC 1 -// #define KV_STEP 16 #endif #define SystolicDepth 8 @@ -48,6 +46,9 @@ #define KV_SCALE_ZP_SIZE 0 // 4: scale/zp size + +#define DEBUG_ENABLE 0 +#if DEBUG_ENABLE template void show(matrix mat) { for(int m = 0; m < M; m ++) { @@ -59,6 +60,7 @@ void show(matrix mat) { } printf("]\n"); } +#endif extern "C" _GENX_MAIN_ void KERNEL_NAME( @@ -71,9 +73,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( int* block_indices_begins [[type("svmptr_t")]], int* subsequence_begins [[type("svmptr_t")]], half* output [[type("svmptr_t")]], - // half* mask [[type("svmptr_t")]], float* lse [[type("svmptr_t")]], - // int* gws_subseq_mapping [[type("svmptr_t")]], int q_len// 1 ) { //# batch=1, seq_num=1 or >1 @@ -125,7 +125,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( cm_svm_block_read((svmptr_t)(query + qo_offset + k), Qmat[ri].format()); } - // if(wg_thread_id==0 && head_num_idx == 0) { + //if(head_num_idx==0 && partition_idx==1) { // printf("Qmat loaded, wg_thread_id=%d\n", wg_thread_id); // show(Qmat); //} @@ -144,7 +144,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } //# rS = Q @ Kt - //# PARTITION_SUBBLOCK_NUM * [REG_M, REG_K] * [REG_K, REG_N] = PARTITION_SUBBLOCK_NUM * [REG_M, REG_N] + //# PARTITION_SUBBLOCK_NUM * [REG_M, REG_K] * [REG_K, REG_N] = PARTITION_SUBBLOCK_NUM * [REG_M, REG_N] matrix rS = 0; // # each WI can process multiple blocks for(uint block_idx = 0, ki = 0; block_idx < block_num; block_idx++) { @@ -208,10 +208,10 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } } - //if(wg_thread_id==1) { - // printf("rS:\n"); - // show(rS); - //} + // if(head_num_idx==0 && partition_idx==1) { + // printf("rS:\n"); + // show(rS); + // } // online softmax float cur_sum = 0.0f; @@ -222,27 +222,28 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( matrix Pmat = 0; #endif { - //# Load Mask into register - // matrix MaskMat; - // uint mask_offset = seq_idx * q_len * kv_len + wg_thread_id * KV_PARTITION_SIZE; - // cm_svm_block_read((svmptr_t)(mask + mask_offset), MaskMat.format()); - rS = cm_mul(rS, (float)SCALE_FACTOR); // convert scale_factor into (float), or it will be promoted to double - //rS = cm_add(rS, MaskMat); // printf("leftover_size = %d, leftover_aligned_size = %d, XE_ARCH = %d, PARTITION_SUBBLOCK_NUM * REG_N = %d\n", leftover_size, leftover_aligned_size, XE_ARCH, PARTITION_SUBBLOCK_NUM * REG_N); if(leftover_size > 0) { auto Svec = rS.format(); for(int i = leftover_size; i < PARTITION_SUBBLOCK_NUM * REG_N; i++){ - Svec[i] = -10000000000.0f; + Svec[i] = -3e38f; } } // compute lse constexpr float log2e = 1.4426950408889634f; + constexpr float loge2 = 0.6931471805599453f; vector rS_exp = cm_exp(rS.format()*log2e); + float cur_lse_0 = cm_sum(rS_exp); - cur_lse += cm_sum(rS_exp); + //if(head_num_idx==0 && partition_idx==1) { + // uint lse_offset = seq_idx * HEADS_NUM * kv_partition_num + head_num_idx * kv_partition_num + wg_thread_id; + // printf("LSE[%d]: %f\n", lse_offset, cur_lse); + // printf("rS_exp:\n"); + // show(rS_exp.format()); + //} // compute row_max auto rSv = rS.format(); @@ -257,6 +258,12 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( Pmat= cm_exp((rS - row_max)*log2e); #endif + vector rS_exp_temp = cm_exp((rS.format() - row_max)*log2e); + cur_lse = cm_sum(rS_exp_temp.format()); + cur_lse = cm_log(cur_lse) * loge2 + row_max; // log2(sum(exp(x))) = log2e * log(sum(exp(x))) + //float cur_lse_1 = cm_exp(cur_lse * log2e); + //printf("row_max= %f, cur_lse =%f, cur_lse_0 = %f, cur_lse_1 = %f\n", row_max, cur_lse, cur_lse_0, cur_lse_1); + // compute row sum of P auto rPv = Pmat.format(); cur_sum = cm_sum(rPv[0]); @@ -339,4 +346,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } uint lse_offset = seq_idx * HEADS_NUM * kv_partition_num + head_num_idx * kv_partition_num + wg_thread_id; lse[lse_offset] = cur_lse; + + // printf("LSE[%d]: %f\n", lse_offset, cur_lse); } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm index 001eb39ad1435c..c8217633df4fdb 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm @@ -30,10 +30,16 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #else float total_lse = 0.0; uint lse_offset = batch * total_partition_num + head * kv_partition_num; + constexpr float log2e = 1.4426950408889634f; float* lse_vec = lse + lse_offset; + float lse_max = lse_vec[0]; + for(int k = 1; k < kv_partition_num; k ++) { + lse_max = cm_max(lse_vec[k], lse_max); + } #pragma unroll for(int k = 0; k < kv_partition_num; k ++) { - total_lse += lse_vec[k]; + float lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); + total_lse += lse_value; } #endif @@ -45,7 +51,8 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( for(int k = 0; k < kv_partition_num; k ++) { cm_svm_block_read((svmptr_t)(input + input_offset), data_mat.format()); input_offset += HEAD_SIZE; - out_mat += cm_mul(data_mat, (float)(lse_vec[k]/total_lse)); + float lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); + out_mat += cm_mul(data_mat, (float)(lse_value/total_lse)); } // write output diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index bb00c86cbc1e71..7e487130ee0fcb 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -104,8 +104,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { std::vector internal_buffers; const auto desc = params.typed_desc(); - const auto indexes_dt = ov::element::u8; - const auto element_size = 4; // 4 bytes + const auto indexes_dt = ov::element::f32; auto stage = PagedAttentionStage::UNKNOWN; auto rt_params = static_cast(m_rt_params.get()); @@ -129,11 +128,10 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto buf_elements_count = static_cast(total_tokens * desc->heads_num * num_of_partitions); auto tmp_out_elements_count = static_cast(total_tokens * desc->heads_num * desc->v_head_size * num_of_partitions); - internal_buffers.emplace_back(tmp_out_elements_count * element_size, indexes_dt); // 0: intermediate partition output - internal_buffers.emplace_back(buf_elements_count * element_size, indexes_dt); // 1: softmax exp_sums + internal_buffers.emplace_back(tmp_out_elements_count, ov::element::f16); // 0: intermediate partition output + internal_buffers.emplace_back(buf_elements_count, ov::element::f32); // 1: softmax exp_sums - GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * element_size - << " exp_sums=" << buf_elements_count * element_size << std::endl; + GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * 2 << " exp_sums=" << buf_elements_count * 4 << std::endl; } else { internal_buffers.emplace_back(16, indexes_dt); // 0: intermediate partition output internal_buffers.emplace_back(16, indexes_dt); // 1: softmax exp_sums diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index b2c1af29652bf2..59f6028aee6390 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -26,7 +26,7 @@ constexpr size_t WG_SIZE = 16; constexpr size_t reduce_split_step = 16; } // namespace -#define DEBUG_ENABLED 1 +#define DEBUG_ENABLED 0 // This function returns the kv_step and kv_split_len based on the architecture. // return {kv_step, kv_split_len} @@ -229,6 +229,9 @@ JitConstants PagedAttentionGeneratorBase::get_jit_constants(const kernel_impl_pa return jit; } +//----------------------------------------------------------------------------------------------------------------- +// KV cache update generator +//----------------------------------------------------------------------------------------------------------------- JitConstants PagedAttentionGeneratorKVCacheUpdate::get_jit_constants(const kernel_impl_params& params) const { auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); @@ -243,9 +246,6 @@ JitConstants PagedAttentionGeneratorKVCacheUpdate::get_jit_constants(const kerne return jit; } -//----------------------------------------------------------------------------------------------------------------- -// KV cache update generator -//----------------------------------------------------------------------------------------------------------------- Arguments PagedAttentionGeneratorKVCacheUpdate::get_arguments_desc(const kernel_impl_params& params) const { Arguments args; args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY}); // queries @@ -259,8 +259,10 @@ Arguments PagedAttentionGeneratorKVCacheUpdate::get_arguments_desc(const kernel_ // scalar args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // key_pitch - args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // value_pitch - args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // batch_size_in_sequences + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // key_offset + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // value_pitch + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // value_offset + args.push_back({ArgumentDescriptor::Types::SCALAR, 4}); // batch_size_in_sequences return args; } @@ -284,6 +286,31 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() auto key_layout = params.input_layouts[PagedAttentionInputIdx::KEY]; auto value_layout = params.input_layouts[PagedAttentionInputIdx::VALUE]; + if (0) { // Debug + std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " + << "key_layout: " << key_layout.to_string() << ", value_layout: " << value_layout.to_string() << std::endl; + std::cout << "\tkey_dims = ["; + for (auto& it : key_layout.get_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\tkey_pads = ["; + for (auto& it : key_layout.get_padded_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\tvalue_dims = ["; + for (auto& it : value_layout.get_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\tvalue_pads = ["; + for (auto& it : value_layout.get_padded_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + } + auto get_simple_pitch = [](const layout& layout) { size_t pitch = 1; auto dims_padding = layout.get_padded_dims(); @@ -298,14 +325,31 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() key_pitch = get_simple_pitch(key_layout); value_pitch = get_simple_pitch(value_layout); + auto get_simple_offset = [](const layout& layout) { + size_t offset = 0; + const auto& data_padding = layout.data_padding; + const auto& lower_pads = data_padding._lower_size; + for (auto& it : lower_pads) { + if (it > 0) { + offset = it; + break; + } + } + return offset; + }; + size_t key_offset = get_simple_offset(key_layout); + size_t value_offset = get_simple_offset(value_layout); + if (DEBUG_ENABLED) { // Debug std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " - << "key_pitch: " << key_pitch << ", value_pitch: " << value_pitch << std::endl; + << "key_pitch: " << key_pitch << ", key_offset: " << key_offset + << ", value_pitch: " << value_pitch << ", value_offset: " << value_offset + << std::endl; } // TODO: support multiple sequences size_t batch_size_in_sequences = 1; - std::vector scaler_value = {key_pitch, value_pitch, batch_size_in_sequences}; + std::vector scaler_value = {key_pitch, key_offset, value_pitch, value_offset, batch_size_in_sequences}; scalars.resize(scaler_value.size()); for (size_t i = 0; i < scaler_value.size(); ++i) { @@ -371,7 +415,7 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con auto query_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; - if (DEBUG_ENABLED) { // Debug + if (0 && DEBUG_ENABLED) { // Debug std::cout << "PagedAttentionGeneratorMultiToken::get_dispatch_data_func: query_layout: " << query_layout.to_string() << std::endl; std::cout << "\tquery_dims = ["; for (auto& it : query_layout.get_dims()) { @@ -419,7 +463,7 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con //----------------------------------------------------------------------------------------------------------------- JitConstants PagedAttentionGeneratorSingleToken::get_jit_constants(const kernel_impl_params& params) const { auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); - jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); + // jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); auto desc = params.typed_desc(); const float scale_factor = 1.0 / std::sqrt(static_cast(desc->k_head_size)); const size_t kv_partition_size = get_partition_size(); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 82f6888f668a57..497d6d89b42a22 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1261,6 +1261,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing: paged_attention_test_params{ {{1, 1023}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 127}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 129}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 32}}, 28, 128, 128, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token #else /* with scores output, use SnapKV */ paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token From 2743aab004f21d028e34e9308abe1025bd1ff716 Mon Sep 17 00:00:00 2001 From: "river.li" Date: Tue, 2 Sep 2025 15:51:27 +0800 Subject: [PATCH 05/96] Fix 2nd non-16 alignment accuracy issue --- .../src/graph/impls/cm/pa_single_token.cm | 18 +++++++++--------- .../impls/cm/pa_single_token_finalization.cm | 10 ++++++---- .../src/graph/impls/cm/paged_attention.cpp | 4 ++-- .../src/graph/impls/cm/paged_attention_gen.cpp | 13 ++++++++----- 4 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index dc7a10d766ace2..a2c8839189e622 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -72,7 +72,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( int* block_indices [[type("svmptr_t")]], int* block_indices_begins [[type("svmptr_t")]], int* subsequence_begins [[type("svmptr_t")]], - half* output [[type("svmptr_t")]], + float* output [[type("svmptr_t")]], float* lse [[type("svmptr_t")]], int q_len// 1 ) { @@ -171,8 +171,8 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #endif uint kv_pos_end = KV_BLOCK_SIZE; - if(block_idx == block_num - 1 && start_block_idx > 0) { - kv_pos_end = start_block_idx % KV_BLOCK_SIZE; + if(block_idx == block_num - 1 && leftover_aligned_size > 0) { + kv_pos_end = leftover_aligned_size; } for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += KV_STEP, ki++) { auto rSvec = rS[ki].format(); @@ -294,8 +294,8 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #endif uint kv_pos_end = KV_BLOCK_SIZE; - if(block_idx == block_num - 1 && start_block_idx > 0) { - kv_pos_end = start_block_idx % KV_BLOCK_SIZE; + if(block_idx == block_num - 1 && leftover_aligned_size > 0) { + kv_pos_end = leftover_aligned_size; } for(int kv_pos =0; kv_pos < kv_pos_end; kv_pos += REG_K, ki++) { uint kv_offset_y = kv_pos; @@ -331,18 +331,18 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( // } //# save Output - matrix cur_O_f16; + matrix cur_O; uint o_offset = seq_idx * kv_partition_num * HEADS_NUM * HEAD_SIZE + kv_partition_num * head_num_idx * HEAD_SIZE + wg_thread_id * HEAD_SIZE; float div_cur_sum = 1.0/cur_sum; #pragma unroll for(int k = 0, ri=0; k < HEAD_SIZE; k += REG_N, ri++) { auto cO = Omat[ri].format(); #if XE_ARCH==1 - cur_O_f16= cm_mul(cO, div_cur_sum); + cur_O= cm_mul(cO, div_cur_sum); #else - cur_O_f16= cm_div_ieee(cO, cur_sum); + cur_O= cm_div_ieee(cO, cur_sum); #endif - cm_svm_block_write((svmptr_t)(output + o_offset + k),cur_O_f16.format()); + cm_svm_block_write((svmptr_t)(output + o_offset + k),cur_O.format()); } uint lse_offset = seq_idx * HEADS_NUM * kv_partition_num + head_num_idx * kv_partition_num + wg_thread_id; lse[lse_offset] = cur_lse; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm index c8217633df4fdb..a46e072100a83f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm @@ -11,7 +11,7 @@ //cm_sdpa_2nd_reduce extern "C" _GENX_MAIN_ void KERNEL_NAME( // extern "C" _GENX_MAIN_ void cm_sdpa_2nd_reduce( - half* input [[type("svmptr_t")]], // + float* input [[type("svmptr_t")]], half* output [[type("svmptr_t")]], float* lse [[type("svmptr_t")]], int kv_partition_num @@ -44,16 +44,18 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #endif // load input, total_partition_num = head_nums * kv_partition_num; + matrix out_mat_f32 = 0; matrix out_mat = 0; - matrix data_mat; + matrix data_mat; uint input_offset = batch * total_partition_num * HEAD_SIZE + head * kv_partition_num * HEAD_SIZE + offset; #pragma unroll for(int k = 0; k < kv_partition_num; k ++) { - cm_svm_block_read((svmptr_t)(input + input_offset), data_mat.format()); + cm_svm_block_read((svmptr_t)(input + input_offset), data_mat.format()); input_offset += HEAD_SIZE; float lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); - out_mat += cm_mul(data_mat, (float)(lse_value/total_lse)); + out_mat_f32 += cm_mul(data_mat, (float)(lse_value/total_lse)); } + out_mat = out_mat_f32.format(); // write output uint output_offset = batch * HEADS_NUM * HEAD_SIZE + head * HEAD_SIZE + offset; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 7e487130ee0fcb..c51421721048bd 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -128,10 +128,10 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto buf_elements_count = static_cast(total_tokens * desc->heads_num * num_of_partitions); auto tmp_out_elements_count = static_cast(total_tokens * desc->heads_num * desc->v_head_size * num_of_partitions); - internal_buffers.emplace_back(tmp_out_elements_count, ov::element::f16); // 0: intermediate partition output + internal_buffers.emplace_back(tmp_out_elements_count, ov::element::f32); // 0: intermediate partition output internal_buffers.emplace_back(buf_elements_count, ov::element::f32); // 1: softmax exp_sums - GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * 2 << " exp_sums=" << buf_elements_count * 4 << std::endl; + GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * 4 << " exp_sums=" << buf_elements_count * 4 << std::endl; } else { internal_buffers.emplace_back(16, indexes_dt); // 0: intermediate partition output internal_buffers.emplace_back(16, indexes_dt); // 1: softmax exp_sums diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 59f6028aee6390..205dab2b367062 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -342,9 +342,8 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() if (DEBUG_ENABLED) { // Debug std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " - << "key_pitch: " << key_pitch << ", key_offset: " << key_offset - << ", value_pitch: " << value_pitch << ", value_offset: " << value_offset - << std::endl; + << "kv_len: " << kv_len << ", key_pitch: " << key_pitch << ", key_offset: " << key_offset << ", value_pitch: " << value_pitch + << ", value_offset: " << value_offset << ", "<< std::endl; } // TODO: support multiple sequences @@ -527,9 +526,13 @@ DispatchDataFunc PagedAttentionGeneratorSingleToken::get_dispatch_data_func() co scalars.resize(scaler_value.size()); if (DEBUG_ENABLED) { // Debug + size_t kv_len = get_kv_len(params, PagedAttentionStage::GENERATE); + size_t max_context_len = get_max_context_len(params); + size_t past_len = get_past_len(params, 0); std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " - << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", gws: [" << wgs.global[0] << ", " - << wgs.global[1] << ", " << wgs.global[2] << "]" + << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", kv_len: " << kv_len + << ", max_context_len = " << max_context_len << ", past_len = " << past_len << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] + << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } From 65b9cc7e89481ad3bb1a0202f9385aa7190641d7 Mon Sep 17 00:00:00 2001 From: "river.li" Date: Tue, 2 Sep 2025 21:42:03 +0800 Subject: [PATCH 06/96] Set best partition size for 2nd --- src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm | 4 ++-- .../intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index a2c8839189e622..59978e1d615666 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -172,7 +172,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint kv_pos_end = KV_BLOCK_SIZE; if(block_idx == block_num - 1 && leftover_aligned_size > 0) { - kv_pos_end = leftover_aligned_size; + kv_pos_end = leftover_size % KV_BLOCK_SIZE; } for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += KV_STEP, ki++) { auto rSvec = rS[ki].format(); @@ -295,7 +295,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint kv_pos_end = KV_BLOCK_SIZE; if(block_idx == block_num - 1 && leftover_aligned_size > 0) { - kv_pos_end = leftover_aligned_size; + kv_pos_end = leftover_size % KV_BLOCK_SIZE; } for(int kv_pos =0; kv_pos < kv_pos_end; kv_pos += REG_K, ki++) { uint kv_offset_y = kv_pos; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 205dab2b367062..9e2af75c56b9d2 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -157,8 +157,8 @@ size_t get_partition_size() { // size_t k_partition_blok_num = (kv_len + 8191) / 8192; // if (k_partition_blok_num < 1) // k_partition_blok_num = 1; - const size_t k_partition_blok_num = 1; - return k_partition_blok_num * PA_KV_CACHE_BLOCK_SIZE; + const size_t k_partition_blok_num = 16; + return k_partition_blok_num * PA_KV_CACHE_BLOCK_SIZE; // 128 } size_t get_partition_num(const size_t kv_len) { From c4a1659b3d735e604f28ed66ac8f2ba41e604012 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 3 Sep 2025 09:49:53 +0800 Subject: [PATCH 07/96] update KV_BLOCK_SIZE to 256 --- .../include/intel_gpu/primitives/paged_attention.hpp | 2 +- .../src/graph/impls/cm/paged_attention_gen.cpp | 5 +++-- .../src/graph/impls/cm/paged_attention_gen.hpp | 2 +- src/plugins/intel_gpu/src/graph/paged_attention.cpp | 10 +++++----- .../intel_gpu/src/plugin/transformations_pipeline.cpp | 4 ++-- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index f5fe561dc84046..962e90dcf3ffd1 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -38,7 +38,7 @@ struct paged_attention : public primitive_base { XATTENTION_STRIDE = 19, }; - static constexpr size_t block_size = 16; + static constexpr size_t block_size = 256; paged_attention() : primitive_base("", {}) {} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 9e2af75c56b9d2..4d508ac3064d88 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -157,8 +157,9 @@ size_t get_partition_size() { // size_t k_partition_blok_num = (kv_len + 8191) / 8192; // if (k_partition_blok_num < 1) // k_partition_blok_num = 1; - const size_t k_partition_blok_num = 16; - return k_partition_blok_num * PA_KV_CACHE_BLOCK_SIZE; // 128 + // const size_t k_partition_blok_num = 16; + // return k_partition_blok_num * PA_KV_CACHE_BLOCK_SIZE; // 128 + return 256; } size_t get_partition_num(const size_t kv_len) { diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index a3149054ec6617..ffabc6c9cac128 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -31,7 +31,7 @@ constexpr auto get_pa_build_options() { } // BLOCK_SIZE can be 16/32/64/128/256 -#define PA_KV_CACHE_BLOCK_SIZE 16 +#define PA_KV_CACHE_BLOCK_SIZE 256 // sparse attention block size is set to 1 to disable sparse attention support in CM kernels #define PA_SPARSE_BLOCK_SIZE 1 diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index 2ec8e127b26cc0..1478eed6c0d8c9 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -110,11 +110,11 @@ paged_attention_inst::typed_primitive_inst(network& network, const paged_attenti : parent(network, node) { const auto desc = node.get_primitive(); - const auto k_head_size = desc->k_head_size; - const auto v_head_size = desc->v_head_size; + // const auto k_head_size = desc->k_head_size; + // const auto v_head_size = desc->v_head_size; const auto heads_num = desc->heads_num; const auto kv_heads_num = desc->kv_heads_num; - const auto pa_block_size = desc->block_size; + // const auto pa_block_size = desc->block_size; if (desc->has_alibi) { const auto alibi_input_idx = 11; @@ -123,7 +123,7 @@ paged_attention_inst::typed_primitive_inst(network& network, const paged_attenti } OPENVINO_ASSERT(heads_num % kv_heads_num == 0); - OPENVINO_ASSERT(k_head_size % pa_block_size == 0); - OPENVINO_ASSERT(v_head_size % pa_block_size == 0); + // OPENVINO_ASSERT(k_head_size % pa_block_size == 0); + // OPENVINO_ASSERT(v_head_size % pa_block_size == 0); } } // namespace cldnn diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index e51f9c1b180b74..b9e1bb8d66e460 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -512,11 +512,11 @@ void TransformationsPipeline::apply(std::shared_ptr func) { kv_cache_config.keyCachePrecision = config.get_kv_cache_precision(); kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); kv_cache_config.inferencePrecision = infer_precision; - kv_cache_config.keyCacheBlockSize = 16; + kv_cache_config.keyCacheBlockSize = 256; kv_cache_config.keyCacheDimOrder = {0, 1, 2, 3}; kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; - kv_cache_config.valueCacheBlockSize = 16; + kv_cache_config.valueCacheBlockSize = 256; kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; kv_cache_config.valueCacheQuantBychannel = false; kv_cache_config.valueCacheGroupSize = 0; From 62a222f90ab4de350d03bd9f00b0809eacb3e73d Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 3 Sep 2025 17:08:20 +0800 Subject: [PATCH 08/96] initiate xattention integration --- .../src/graph/impls/cm/include/estimate.hpp | 1107 +++++++++++++++++ .../src/graph/impls/cm/include/find_block.hpp | 197 +++ .../src/graph/impls/cm/include/sort.hpp | 244 ++++ .../src/graph/impls/cm/paged_attention.cpp | 26 + .../graph/impls/cm/paged_attention_gen.cpp | 243 +++- .../graph/impls/cm/paged_attention_gen.hpp | 44 +- .../src/graph/impls/cm/xattn_find_block.cm | 65 + .../src/graph/impls/cm/xattn_gemm_qk.cm | 112 ++ 8 files changed, 2036 insertions(+), 2 deletions(-) create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp new file mode 100644 index 00000000000000..879c9bda50359c --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp @@ -0,0 +1,1107 @@ +/* + * Copyright (c) 2020-2023, Intel Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +#include +#include + +#if defined(SHIM) || defined(CMRT_EMU) +#define ATTR +#define ATTR_BUF +#define CM_LOCAL_BARRIER 0x20 +#include "emu/block2d.h" +#else +#define ATTR [[type("svmptr_t")]] +#define ATTR_BUF [[type("buffer_t")]] +#endif + +#define MYMIN(x, y) ((x) < (y) ? (x) : (y)) + +template +void show(const vector mat) { + printf("vector [%d]:\n[", N); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[n]); + } + printf("]\n"); +} + +template +void show_i(const vector mat) { + printf("vector [%d]:\n[", N); + for(int n = 0; n < N; n ++) { + printf("%d,", mat[n]); + } + printf("]\n"); +} + +template +void show(const matrix mat) { + printf("Matrix [%d, %d]:\n", M, N); + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[m][n]); + } + printf("],\n"); + } + printf("]\n"); +} + +template +void show(const matrix_ref mat) { + printf("Matrix [%d, %d]:\n", M, N); + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[m][n]); + } + printf("],\n"); + } + printf("]\n"); +} + +template +CM_INLINE void Transpose_8x8(matrix_ref in, matrix_ref out) { + matrix temp; + temp.row(0) = in.template select<2, 1, 4, 2>(0, 0); + temp.row(1) = in.template select<2, 1, 4, 2>(2, 0); + temp.row(2) = in.template select<2, 1, 4, 2>(4, 0); + temp.row(3) = in.template select<2, 1, 4, 2>(6, 0); + temp.row(4) = in.template select<2, 1, 4, 2>(0, 1); + temp.row(5) = in.template select<2, 1, 4, 2>(2, 1); + temp.row(6) = in.template select<2, 1, 4, 2>(4, 1); + temp.row(7) = in.template select<2, 1, 4, 2>(6, 1); + + out.row(0) = temp.template select<4, 1, 2, 4>(0, 0); + out.row(2) = temp.template select<4, 1, 2, 4>(0, 1); + out.row(4) = temp.template select<4, 1, 2, 4>(0, 2); + out.row(6) = temp.template select<4, 1, 2, 4>(0, 3); + out.row(1) = temp.template select<4, 1, 2, 4>(4, 0); + out.row(3) = temp.template select<4, 1, 2, 4>(4, 1); + out.row(5) = temp.template select<4, 1, 2, 4>(4, 2); + out.row(7) = temp.template select<4, 1, 2, 4>(4, 3); +} + +template +CM_INLINE void Transpose_8x32(matrix_ref in, matrix_ref out) { + Transpose_8x8(in.template select<8, 1, 8, 1>(0, 0), out.template select<8, 1, 8, 1>( 0, 0)); + Transpose_8x8(in.template select<8, 1, 8, 1>(0, 8), out.template select<8, 1, 8, 1>( 8, 0)); + Transpose_8x8(in.template select<8, 1, 8, 1>(0, 16), out.template select<8, 1, 8, 1>(16, 0)); + Transpose_8x8(in.template select<8, 1, 8, 1>(0, 24), out.template select<8, 1, 8, 1>(24, 0)); +} + +template +CM_INLINE void Transpose_4x32(matrix_ref in, matrix_ref out) { + matrix temp; + temp.row(0) = in.template select<4, 1, 8, 4>(0, 0); + temp.row(1) = in.template select<4, 1, 8, 4>(0, 1); + temp.row(2) = in.template select<4, 1, 8, 4>(0, 2); + temp.row(3) = in.template select<4, 1, 8, 4>(0, 3); + + out.row( 0) = temp.template select<1, 1, 4, 8>(0, 0); + out.row( 1) = temp.template select<1, 1, 4, 8>(1, 0); + out.row( 2) = temp.template select<1, 1, 4, 8>(2, 0); + out.row( 3) = temp.template select<1, 1, 4, 8>(3, 0); + out.row( 4) = temp.template select<1, 1, 4, 8>(0, 1); + out.row( 5) = temp.template select<1, 1, 4, 8>(1, 1); + out.row( 6) = temp.template select<1, 1, 4, 8>(2, 1); + out.row( 7) = temp.template select<1, 1, 4, 8>(3, 1); + out.row( 8) = temp.template select<1, 1, 4, 8>(0, 2); + out.row( 9) = temp.template select<1, 1, 4, 8>(1, 2); + out.row(10) = temp.template select<1, 1, 4, 8>(2, 2); + out.row(11) = temp.template select<1, 1, 4, 8>(3, 2); + out.row(12) = temp.template select<1, 1, 4, 8>(0, 3); + out.row(13) = temp.template select<1, 1, 4, 8>(1, 3); + out.row(14) = temp.template select<1, 1, 4, 8>(2, 3); + out.row(15) = temp.template select<1, 1, 4, 8>(3, 3); + out.row(16) = temp.template select<1, 1, 4, 8>(0, 4); + out.row(17) = temp.template select<1, 1, 4, 8>(1, 4); + out.row(18) = temp.template select<1, 1, 4, 8>(2, 4); + out.row(19) = temp.template select<1, 1, 4, 8>(3, 4); + out.row(20) = temp.template select<1, 1, 4, 8>(0, 5); + out.row(21) = temp.template select<1, 1, 4, 8>(1, 5); + out.row(22) = temp.template select<1, 1, 4, 8>(2, 5); + out.row(23) = temp.template select<1, 1, 4, 8>(3, 5); + out.row(24) = temp.template select<1, 1, 4, 8>(0, 6); + out.row(25) = temp.template select<1, 1, 4, 8>(1, 6); + out.row(26) = temp.template select<1, 1, 4, 8>(2, 6); + out.row(27) = temp.template select<1, 1, 4, 8>(3, 6); + out.row(28) = temp.template select<1, 1, 4, 8>(0, 7); + out.row(29) = temp.template select<1, 1, 4, 8>(1, 7); + out.row(30) = temp.template select<1, 1, 4, 8>(2, 7); + out.row(31) = temp.template select<1, 1, 4, 8>(3, 7); +} + +template +CM_INLINE void Transpose_32x32(matrix_ref in, matrix_ref out) { + matrix temp; + temp.row( 0) = in.template select<8, 1, 4, 8>( 0, 0); + temp.row( 1) = in.template select<8, 1, 4, 8>( 8, 0); + temp.row( 2) = in.template select<8, 1, 4, 8>(16, 0); + temp.row( 3) = in.template select<8, 1, 4, 8>(24, 0); + temp.row( 4) = in.template select<8, 1, 4, 8>( 0, 1); + temp.row( 5) = in.template select<8, 1, 4, 8>( 8, 1); + temp.row( 6) = in.template select<8, 1, 4, 8>(16, 1); + temp.row( 7) = in.template select<8, 1, 4, 8>(24, 1); + temp.row( 8) = in.template select<8, 1, 4, 8>( 0, 2); + temp.row( 9) = in.template select<8, 1, 4, 8>( 8, 2); + temp.row(10) = in.template select<8, 1, 4, 8>(16, 2); + temp.row(11) = in.template select<8, 1, 4, 8>(24, 2); + temp.row(12) = in.template select<8, 1, 4, 8>( 0, 3); + temp.row(13) = in.template select<8, 1, 4, 8>( 8, 3); + temp.row(14) = in.template select<8, 1, 4, 8>(16, 3); + temp.row(15) = in.template select<8, 1, 4, 8>(24, 3); + temp.row(16) = in.template select<8, 1, 4, 8>( 0, 4); + temp.row(17) = in.template select<8, 1, 4, 8>( 8, 4); + temp.row(18) = in.template select<8, 1, 4, 8>(16, 4); + temp.row(19) = in.template select<8, 1, 4, 8>(24, 4); + temp.row(20) = in.template select<8, 1, 4, 8>( 0, 5); + temp.row(21) = in.template select<8, 1, 4, 8>( 8, 5); + temp.row(22) = in.template select<8, 1, 4, 8>(16, 5); + temp.row(23) = in.template select<8, 1, 4, 8>(24, 5); + temp.row(24) = in.template select<8, 1, 4, 8>( 0, 6); + temp.row(25) = in.template select<8, 1, 4, 8>( 8, 6); + temp.row(26) = in.template select<8, 1, 4, 8>(16, 6); + temp.row(27) = in.template select<8, 1, 4, 8>(24, 6); + temp.row(28) = in.template select<8, 1, 4, 8>( 0, 7); + temp.row(29) = in.template select<8, 1, 4, 8>( 8, 7); + temp.row(30) = in.template select<8, 1, 4, 8>(16, 7); + temp.row(31) = in.template select<8, 1, 4, 8>(24, 7); + + out.row( 0) = temp.template select<4, 1, 8, 4>( 0, 0); + out.row( 1) = temp.template select<4, 1, 8, 4>( 4, 0); + out.row( 2) = temp.template select<4, 1, 8, 4>( 8, 0); + out.row( 3) = temp.template select<4, 1, 8, 4>(12, 0); + out.row( 4) = temp.template select<4, 1, 8, 4>(16, 0); + out.row( 5) = temp.template select<4, 1, 8, 4>(20, 0); + out.row( 6) = temp.template select<4, 1, 8, 4>(24, 0); + out.row( 7) = temp.template select<4, 1, 8, 4>(28, 0); + out.row( 8) = temp.template select<4, 1, 8, 4>( 0, 1); + out.row( 9) = temp.template select<4, 1, 8, 4>( 4, 1); + out.row(10) = temp.template select<4, 1, 8, 4>( 8, 1); + out.row(11) = temp.template select<4, 1, 8, 4>(12, 1); + out.row(12) = temp.template select<4, 1, 8, 4>(16, 1); + out.row(13) = temp.template select<4, 1, 8, 4>(20, 1); + out.row(14) = temp.template select<4, 1, 8, 4>(24, 1); + out.row(15) = temp.template select<4, 1, 8, 4>(28, 1); + out.row(16) = temp.template select<4, 1, 8, 4>( 0, 2); + out.row(17) = temp.template select<4, 1, 8, 4>( 4, 2); + out.row(18) = temp.template select<4, 1, 8, 4>( 8, 2); + out.row(19) = temp.template select<4, 1, 8, 4>(12, 2); + out.row(20) = temp.template select<4, 1, 8, 4>(16, 2); + out.row(21) = temp.template select<4, 1, 8, 4>(20, 2); + out.row(22) = temp.template select<4, 1, 8, 4>(24, 2); + out.row(23) = temp.template select<4, 1, 8, 4>(28, 2); + out.row(24) = temp.template select<4, 1, 8, 4>( 0, 3); + out.row(25) = temp.template select<4, 1, 8, 4>( 4, 3); + out.row(26) = temp.template select<4, 1, 8, 4>( 8, 3); + out.row(27) = temp.template select<4, 1, 8, 4>(12, 3); + out.row(28) = temp.template select<4, 1, 8, 4>(16, 3); + out.row(29) = temp.template select<4, 1, 8, 4>(20, 3); + out.row(30) = temp.template select<4, 1, 8, 4>(24, 3); + out.row(31) = temp.template select<4, 1, 8, 4>(28, 3); +} + +// group_count: M = group_count * group_size, group_size is the element count of current reduction +// op: 0-max, 1-sum +// M: before reduce element count, N: row count, stop: element count M must be larger than +template +CM_INLINE constexpr auto reduce2d(matrix_ref src) { + constexpr int group_size = M / group_count; + if constexpr (N > stop) { + matrix result; + // half of group will be reduced + constexpr int new_group_size = group_size / 2; + constexpr int new_group_count = group_count * 2; +#pragma unroll + for (int i = 0; i < N / 2; i++) { + matrix new_top, new_bot; + auto top = src.row(2 * i + 0).format(); + auto bot = src.row(2 * i + 1).format(); + constexpr int v_stride = new_group_count == 2 ? 1 : 2; + + new_top.select(0) = top.select(0); + new_top.select(new_group_count / 2) = bot.select(0); + new_bot.select(0) = top.select(1); + new_bot.select(new_group_count / 2) = bot.select(1); + if constexpr (op == 0) { + result[i] = cm_max(new_top.format(), new_bot.format()); + } else { + result[i] = (new_top.format() + new_bot.format()); + } + } + + return reduce2d(result); + } else { + matrix dst = src; + return dst; + } +} + +template +CM_INLINE void read_1d(vector_ref out, svmptr_t base) { + cm_ptr_block_read((T*)base, out); +} + +template +CM_INLINE void read_2d(matrix_ref out, svmptr_t base, uint pitch) { +#pragma unroll + for (int i = 0; i < out.n_rows(); i++, base += pitch) { + cm_ptr_block_read((T*)base, out.row(i)); + } +} + +template +CM_INLINE void write_2d(matrix_ref out, svmptr_t base, uint pitch) { +#pragma unroll + for (int i = 0; i < out.n_rows(); i++, base += pitch) { + cm_ptr_block_write((TSRC*)base, out.row(i)); + } +} + +template +CM_INLINE void write_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { +#pragma unroll + for (int i = 0; i < out.n_rows(); i++, offset += pitch) { + cm_store(base, offset, out.row(i).format()); + } +} + + +#if USE_KQ == 1 && (BLOCK_SG_M == 64 && BLOCK_SG_N == 32) +// register tile: [8, 2] aka[(8*8,16), (16, 16*2)] +// src_a is key, src_b is query +CM_INLINE void gemm_kq_64x32_xe2(uint id_wg_m, uint id_wg_n, uint hq, uint slm, svmptr_t key_cache, svmptr_t query, svmptr_t block_indices ATTR, svmptr_t block_indices_begins ATTR, svmptr_t kq_max ATTR, svmptr_t kq_max_wg ATTR, svmptr_t kq_exp_partial_sum ATTR, +uint M, uint N, uint K, uint query_stride, uint q_start_strided) { + constexpr int SG_SIZE = 16; + constexpr int BLOCK_WG_K = 64; // same in sg +#ifndef BLOCK_SG_M + #define BLOCK_SG_M 64 + #define BLOCK_SG_N 32 + #define SG_M 4 + #define SG_N 4 + #define HEAD_SIZE 128 + #define KV_BLOCK_SIZE 256 + #define STRIDE 16 +#endif + // xehpg DPAS spec: dst: [8, 8], repeat: 1~8, depth: 8 + static constexpr int REPEAT = 8; + static constexpr int DEPTH = 8; + static constexpr int BLOCK_REG_M = REPEAT; + static constexpr int BLOCK_REG_N = SG_SIZE; + static constexpr int BLOCK_DPAS_C = BLOCK_REG_M * BLOCK_REG_N; + static constexpr int VNNI = sizeof(half); + static constexpr int BLOCK_REG_K = DEPTH * sizeof(int) / VNNI; + static constexpr int BLOCK_REG_A = BLOCK_REG_M * BLOCK_REG_K; + static constexpr int BLOCK_REG_B = BLOCK_REG_N * BLOCK_REG_K; + static constexpr int BLOCK_WG_M = SG_M * BLOCK_SG_M; + static constexpr int BLOCK_WG_N = SG_N * BLOCK_SG_N; + // register blocking + static constexpr int REG_M = BLOCK_SG_M / BLOCK_REG_M; + static constexpr int REG_N = BLOCK_SG_N / BLOCK_REG_N; + static constexpr int REG_K = BLOCK_WG_K / BLOCK_REG_K; + static constexpr int REG_MN = REG_M * REG_N; + static constexpr int KEY_LINES_PER_LOAD = KV_BLOCK_SIZE / STRIDE; + + matrix acc = 0; // --> 64*2 regs + uint id_sg_n = cm_local_id(0); + uint id_sg_m = cm_local_id(1); + uint id_sg_mn = id_sg_m * SG_N + id_sg_n; + + static_assert(REG_N == 2, "block_2d_desc for b is manually unrolled by 2"); + static_assert(HEAD_SIZE % BLOCK_WG_K == 0, "K dimension must be multiple of BLOCK_WG_K"); + static_assert(KV_BLOCK_SIZE == 256, "block size of key(key_cache) should be 256"); + uint M_block = (M + BLOCK_WG_M - 1) / BLOCK_WG_M; + uint N_aligned = (N + BLOCK_WG_N - 1) / BLOCK_WG_N * BLOCK_WG_N; + uint M_block_aligned = M_block * (BLOCK_WG_M / (BLOCK_SIZE / STRIDE)); + const uint block_size_div_stride = BLOCK_SIZE / STRIDE; + constexpr half log2e = 1.4426950408889634f; + static_assert(BLOCK_SG_M / block_size_div_stride == 8, "BLOCK_SG_M / block_size_div_stride should be 8"); + static_assert(BLOCK_SG_N == 32, "BLOCK_SG_N should be 32"); + +#if IS_CAUSAL == 1 + if ((int)(id_wg_m * BLOCK_WG_M) >= ((int)id_wg_n + 1) * BLOCK_WG_N + q_start_strided) { + // fill -inf -> max in group, 0 -> exp_sum to make compensation work + { + // current max -> mem + vector max_n = -60000; + // kq_max_wg: [b, hq, M/BLOCK_WG_M, N_aligned] + uint offset = (id_wg_m * N_aligned + id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * sizeof(half); + cm_ptr_store((int*)kq_max_wg, offset, max_n.format()); + } + { + // store + matrix sum_t = 0; + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, N - 1, (uint)(M_block_aligned * sizeof(half) - 1), (uint)(M_block_aligned * sizeof(half) - 1), + (int)((id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) / block_size_div_stride), (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) }; + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + } + + return; + } +#endif + // assume block index coming from 0 in block_indices_begins + int block_index_begin = ((int*)block_indices_begins)[0]; + int* block_indices_p = (int*)block_indices + block_index_begin; + int b_adjacent_between_head = query_stride / STRIDE; + // N[0:16*2]xK[0:16] + lsc::block_2d_desc desc_b0{ query, N - 1, (uint)((query_stride - hq * HEAD_SIZE) * sizeof(half) - 1), (uint)(query_stride * sizeof(half) - 1), + (STRIDE - 1) * b_adjacent_between_head / 2, (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) }; + // prefetch B + static constexpr int SG_MN = SG_M * SG_N; + lsc::block_2d_desc desc_prefetch_b{query, N - 1, (uint)((query_stride - hq * HEAD_SIZE) * sizeof(half) - 1), (uint)(query_stride * sizeof(half) - 1), + (STRIDE - 1) * b_adjacent_between_head, (int)(id_wg_n * BLOCK_WG_N + id_sg_mn * (BLOCK_WG_N / SG_MN)) }; + // N[0:16*2]xK[0:16] --> 8+8 regs + matrix b0, b1; + + // M[0:16]xK[0:32] + uint block_idx = (uint)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * STRIDE / KV_BLOCK_SIZE; + uint offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_a0{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // M[16:32]xK[0:32] + offset = block_indices_p[block_idx + 1] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_a1{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // M[32:48]xK[0:32] + offset = block_indices_p[block_idx + 2] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_a2{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // M[48:64]xK[0:32] + offset = block_indices_p[block_idx + 3] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_a3{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // prefetch A + block_idx = (uint)(id_wg_m * BLOCK_WG_M + id_sg_mn * (BLOCK_WG_M / SG_MN)) * STRIDE / KV_BLOCK_SIZE; + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + static_assert(BLOCK_WG_M / SG_MN <= KEY_LINES_PER_LOAD, "prefetch lines should be inside one block"); + lsc::block_2d_desc desc_prefetch_a{ key_cache + offset, BLOCK_WG_M / SG_MN - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // 0~2 M[:]xK[0:16] 2~4 K[16:32] --> 32 * 2 regs + matrix a0, a1, a2, a3; + + // warmup + // prefetch + cm_prefetch(desc_prefetch_b); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + cm_prefetch(desc_prefetch_a); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load b: N[0:16]xK[0:16] + cm_load(b0.row(0).format(), desc_b0); + cm_load(b0.row(1).format(), desc_b0); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + cm_sbarrier(1); + + auto dot = [&](matrix_ref A0, matrix_ref A1, matrix_ref A2, matrix_ref A3, matrix_ref B) { +#pragma unroll + for (int reg_n = 0; reg_n < REG_N; reg_n++) { +#pragma unroll + for (uint reg_m = 0; reg_m < 2; reg_m++) { + acc.row((ushort)(reg_m * REG_N + reg_n)) = cm_dpas(acc.row((ushort)(reg_m * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A0.row((ushort)reg_m).format()); + } +#pragma unroll + for (uint reg_m = 0; reg_m < 2; reg_m++) { + acc.row((ushort)((reg_m + 2) * REG_N + reg_n)) = cm_dpas(acc.row((ushort)((reg_m + 2) * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A1.row((ushort)reg_m).format()); + } +#pragma unroll + for (uint reg_m = 0; reg_m < 2; reg_m++) { + acc.row((ushort)((reg_m + 4) * REG_N + reg_n)) = cm_dpas(acc.row((ushort)((reg_m + 4) * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A2.row((ushort)reg_m).format()); + } +#pragma unroll + for (uint reg_m = 0; reg_m < 2; reg_m++) { + acc.row((ushort)((reg_m + 6) * REG_N + reg_n)) = cm_dpas(acc.row((ushort)((reg_m + 6) * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A3.row((ushort)reg_m).format()); + } + } + }; + + for (uint s = 0; s < STRIDE; s++) { + #pragma unroll + for (uint hs = 0; hs < HEAD_SIZE / BLOCK_WG_K; hs++) { + // prefetch + cm_prefetch(desc_prefetch_b); + cm_prefetch(desc_prefetch_a); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + if (hs == HEAD_SIZE / BLOCK_WG_K - 1) + desc_prefetch_b.set_block_x((STRIDE - 1 - s - 1) * b_adjacent_between_head); + else + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + + // load b: N[0:16*2]xK[16:32] + cm_load(b1.row(0).format(), desc_b0); + cm_load(b1.row(1).format(), desc_b0); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + + // load a: M[0:16*4]xK[0:32] + cm_load(a0.format(), desc_a0); + cm_load(a1.format(), desc_a1); + cm_load(a2.format(), desc_a2); + cm_load(a3.format(), desc_a3); + + desc_a0.set_block_x(desc_a0.get_block_x() + 32); + desc_a1.set_block_x(desc_a1.get_block_x() + 32); + desc_a2.set_block_x(desc_a2.get_block_x() + 32); + desc_a3.set_block_x(desc_a3.get_block_x() + 32); + + dot(a0.select<2, 1, BLOCK_REG_A, 1>(), a1.select<2, 1, BLOCK_REG_A, 1>(), + a2.select<2, 1, BLOCK_REG_A, 1>(), a3.select<2, 1, BLOCK_REG_A, 1>(), + b0); + + // load b: N[0:16*2]xK[32:48] + cm_load(b0.row(0).format(), desc_b0); + cm_load(b0.row(1).format(), desc_b0); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + + dot(a0.select<2, 1, BLOCK_REG_A, 1>(2), a1.select<2, 1, BLOCK_REG_A, 1>(2), + a2.select<2, 1, BLOCK_REG_A, 1>(2), a3.select<2, 1, BLOCK_REG_A, 1>(2), + b1); + + // prefetch + cm_prefetch(desc_prefetch_b); + cm_prefetch(desc_prefetch_a); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load b: N[0:16*2]xK[48:64] + cm_load(b1.row(0).format(), desc_b0); + cm_load(b1.row(1).format(), desc_b0); + if (hs == HEAD_SIZE / BLOCK_WG_K - 1) + desc_b0.set_block_x((STRIDE - 1 - s - 1) * b_adjacent_between_head / 2); + else + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + // load a: M[0:32]xK[32:64] + cm_load(a0.format(), desc_a0); + cm_load(a1.format(), desc_a1); + cm_load(a2.format(), desc_a2); + cm_load(a3.format(), desc_a3); + desc_a0.set_block_x(desc_a0.get_block_x() + 32); + desc_a1.set_block_x(desc_a1.get_block_x() + 32); + desc_a2.set_block_x(desc_a2.get_block_x() + 32); + desc_a3.set_block_x(desc_a3.get_block_x() + 32); + + dot(a0.select<2, 1, BLOCK_REG_A, 1>(), a1.select<2, 1, BLOCK_REG_A, 1>(), + a2.select<2, 1, BLOCK_REG_A, 1>(), a3.select<2, 1, BLOCK_REG_A, 1>(), + b0); + + // load b: N[0:16*4]xK[0:16] + cm_load(b0.row(0).format(), desc_b0); + cm_load(b0.row(1).format(), desc_b0); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + + dot(a0.select<2, 1, BLOCK_REG_A, 1>(2), a1.select<2, 1, BLOCK_REG_A, 1>(2), + a2.select<2, 1, BLOCK_REG_A, 1>(2), a3.select<2, 1, BLOCK_REG_A, 1>(2), + b1); + cm_sbarrier(0); + cm_sbarrier(1); + } + } + + cm_sbarrier(0); + + matrix acc_half; +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M; reg_m++) { +#pragma unroll + for (int reg_n = 0; reg_n < REG_N; reg_n++) { + acc_half.select(reg_m * BLOCK_REG_M, reg_n * BLOCK_REG_N) = + acc.row(reg_m * REG_N + reg_n) * float{INV_S}; + } + } + + // if N(aka query) has tails, the following will not change the accuracy: + // gemm will compute results for the padding N(all should be zeros), the kq_max/kq_max_wg/kq_exp_partial_sum are along the query dimension and + // the results can be dropped in the future stage. To simplify the logic, the size of kq_max/kq_max_wg/kq_exp_partial_sum must be enough to hold + // all tails + padding results. + int m_start = (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M); + m_start = MYMIN(m_start, M); + int m_end = MYMIN(m_start + BLOCK_SG_M, M); + int valid_m = m_end - m_start; + matrix sum_t; + vector seq; + cmtl::cm_vector_assign(seq.select_all(), 0, 1); +#if IS_CAUSAL == 1 + bool skip_mask = false; + // in streaming scenario, the past kvcache length may be arbitrary so valid causal mask of a workgroup may start at arbitrary position + if (m_end <= (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + q_start_strided)) { + // all are inside causal mask == 1 + skip_mask = true; + } else { + vector n_pos = (uint)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + q_start_strided) + seq; + #pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + SIMD_IF_BEGIN (m_start + reg_m > n_pos) { + acc_half.row(reg_m) = half{-60000}; + } SIMD_IF_END; + } + } +#else + bool skip_mask = true; +#endif + // case for valid_m == BLOCK_SG_M but skip_mask == false which needs to handle causal mask: + // query = 128 * 2 + 1, key = 256 * 2 + if (valid_m == BLOCK_SG_M && skip_mask) { + vector max_n = acc_half.row(0); + #pragma unroll + for (uint reg_m = 1; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + max_n = cm_max(max_n, acc_half.row(reg_m)); + } + + { + uint slm_offset = (id_sg_m * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * (uint)sizeof(half); + // current max -> slm + cm_slm_block_write(slm, slm_offset, max_n.format()); + cm_slm_fence(CM_LOCAL_BARRIER); + cm_barrier(); + // max inside wg + cm_slm_block_read(slm, id_sg_n * BLOCK_SG_N * (uint)sizeof(half), max_n.format()); + vector tmp; + #pragma unroll + for (uint i = 1; i < SG_M; i++) { + slm_offset = (i * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * (uint)sizeof(half); + cm_slm_block_read(slm, slm_offset, tmp.format()); + max_n = cm_max(max_n, tmp); + } + // max across wg + // kq_max: [b, hq, N_aligned] + vector max_offsets = (id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + seq) * (uint)sizeof(half); + cm_ptr_atomic((half*)kq_max, max_offsets, max_n); + + // current max -> mem + // kq_max_wg: [b, hq, M/BLOCK_WG_M, N_aligned] + uint offset = (id_wg_m * N_aligned + id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * sizeof(half); + cm_ptr_store((int*)kq_max_wg, offset, max_n.format()); + } + { + // kq_exp_partial_sum: [b, hq, N_aligned, M/(BLOCK_SIZE/STRIDE)] + matrix sum; + #pragma unroll + for (uint m = 0; m < BLOCK_SG_M / block_size_div_stride; m++) { + sum.row(m) = cm_exp((acc_half.row(m * block_size_div_stride) - max_n) * log2e); + #pragma unroll + for (uint sub_m = 1; sub_m < block_size_div_stride; sub_m++) { + uint real_m = m * block_size_div_stride + sub_m; + sum.row(m) += cm_exp((acc_half.row(real_m) - max_n) * log2e); + } + } + + Transpose_8x32(sum, sum_t); + } + } else { + // M tails + vector max_n = -60000; + for (uint reg_m = 0; reg_m < valid_m; reg_m++) { + max_n = cm_max(max_n, acc_half.row(reg_m)); + } + + { + uint slm_offset = (id_sg_m * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * (uint)sizeof(half); + // current max -> slm + cm_slm_block_write(slm, slm_offset, max_n.format()); + cm_slm_fence(CM_LOCAL_BARRIER); + cm_barrier(); + // max inside wg + cm_slm_block_read(slm, id_sg_n * BLOCK_SG_N * (uint)sizeof(half), max_n.format()); + vector tmp; + #pragma unroll + for (uint i = 1; i < SG_M; i++) { + slm_offset = (i * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * (uint)sizeof(half); + cm_slm_block_read(slm, slm_offset, tmp.format()); + max_n = cm_max(max_n, tmp); + } + // max across wg + // kq_max: [b, hq, N_aligned] + vector max_offsets = (id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + seq) * (uint)sizeof(half); + cm_ptr_atomic((half*)kq_max, max_offsets, max_n); + + // current max -> mem + // kq_max_wg: [b, hq, M/BLOCK_WG_M, N_aligned] + uint offset = (id_wg_m * N_aligned + id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * sizeof(half); + cm_ptr_store((int*)kq_max_wg, offset, max_n.format()); + } + { + // kq_exp_partial_sum: [b, hq, N_aligned, M/(BLOCK_SIZE/STRIDE)] + matrix sum = 0; + vector n_pos = (uint)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + q_start_strided) + seq; + #pragma unroll + for (uint m = 0; m < BLOCK_SG_M / block_size_div_stride; m++) { + #pragma unroll + for (uint sub_m = 0; sub_m < block_size_div_stride; sub_m++) { + uint real_m = m * block_size_div_stride + sub_m; +#if IS_CAUSAL == 1 + // to following case: + // 0 0 1 1 + // 0 0 0 1 + // the acc value of first column should be -inf --> max(first column) == -inf --> exp(first column - max) == 1, this is incorrect + // so need to use simd_if to detect per element state + SIMD_IF_BEGIN ((m_start + real_m <= n_pos) & (real_m < valid_m)) { + sum.row(m) += cm_exp((acc_half.row(real_m) - max_n) * log2e); + } SIMD_IF_END; +#else + if (real_m < valid_m) + sum.row(m) += cm_exp((acc_half.row(real_m) - max_n) * log2e); +#endif + } + } + Transpose_8x32(sum, sum_t); + } + } + // store + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, N - 1, (uint)(M_block_aligned * sizeof(half) - 1), (uint)(M_block_aligned * sizeof(half) - 1), + (int)((id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) / block_size_div_stride), (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) }; + cm_store(desc_c, sum_t.select<8, 1, 8, 1>( 0).format()); + cm_store(desc_c, sum_t.select<8, 1, 8, 1>( 8).format()); + cm_store(desc_c, sum_t.select<8, 1, 8, 1>(16).format()); + cm_store(desc_c, sum_t.select<8, 1, 8, 1>(24).format()); +} +#endif + +#if 1 || (BLOCK_SG_M == 64 && BLOCK_SG_N == 32) +// const static int channels_reduce_32[] = { 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30, +// 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31}; +// src_a is query, src_b is key +CM_INLINE void gemm_qk_64x32_xe2(uint id_wg_m, uint id_wg_n, uint hq, uint slm, svmptr_t key_cache ATTR, svmptr_t query ATTR, svmptr_t block_indices ATTR, svmptr_t block_indices_begins ATTR, svmptr_t kq_max_wg ATTR, svmptr_t kq_exp_partial_sum ATTR, +uint M, uint N, uint K, uint query_stride, uint q_start_strided) { + constexpr int SG_SIZE = 16; + constexpr int BLOCK_WG_K = 64; // same in sg +#ifndef BLOCK_SG_M + #define BLOCK_SG_M 64 + #define BLOCK_SG_N 32 + #define SG_M 4 + #define SG_N 4 + #define HEAD_SIZE 128 + #define KV_BLOCK_SIZE 256 + #define STRIDE 16 +#endif + // xehpg DPAS spec: dst: [8, 8], repeat: 1~8, depth: 8 + static constexpr int REPEAT = 8; + static constexpr int DEPTH = 8; + static constexpr int BLOCK_REG_M = REPEAT; + static constexpr int BLOCK_REG_N = SG_SIZE; + static constexpr int BLOCK_DPAS_C = BLOCK_REG_M * BLOCK_REG_N; + static constexpr int VNNI = sizeof(half); + static constexpr int BLOCK_REG_K = DEPTH * sizeof(int) / VNNI; + static constexpr int BLOCK_REG_A = BLOCK_REG_M * BLOCK_REG_K; + static constexpr int BLOCK_REG_B = BLOCK_REG_N * BLOCK_REG_K; + static constexpr int BLOCK_WG_M = SG_M * BLOCK_SG_M; + static constexpr int BLOCK_WG_N = SG_N * BLOCK_SG_N; + // register blocking + static constexpr int REG_M = BLOCK_SG_M / BLOCK_REG_M; + static constexpr int REG_N = BLOCK_SG_N / BLOCK_REG_N; + static constexpr int REG_K = BLOCK_WG_K / BLOCK_REG_K; + static constexpr int REG_MN = REG_M * REG_N; + static constexpr int KEY_LINES_PER_LOAD = KV_BLOCK_SIZE / STRIDE; + + matrix acc = 0; // --> 64*2 regs + uint id_sg_n = cm_local_id(0); + uint id_sg_m = cm_local_id(1); + uint id_sg_mn = id_sg_m * SG_N + id_sg_n; + + static_assert(REG_N == 2, "block_2d_desc for b is manually unrolled by 2"); + static_assert(HEAD_SIZE % BLOCK_WG_K == 0, "K dimension must be multiple of BLOCK_WG_K"); + static_assert(KV_BLOCK_SIZE == 256, "block size of key(key_cache) should be 256"); + uint N_block = (N + BLOCK_WG_N - 1) / BLOCK_WG_N; + uint M_aligned = (M + BLOCK_WG_M - 1) / BLOCK_WG_M * BLOCK_WG_M; + uint K_block_pad = N_block * (BLOCK_WG_N / (BLOCK_SIZE / STRIDE)); + const uint block_size_div_stride = BLOCK_SIZE / STRIDE; + constexpr half log2e = 1.4426950408889634f; + //static_assert(BLOCK_SG_M / block_size_div_stride == 8, "BLOCK_SG_M / block_size_div_stride should be 8"); + +#if IS_CAUSAL == 1 + if ((int)(id_wg_n * BLOCK_WG_N) >= ((int)id_wg_m + 1) * BLOCK_WG_M + q_start_strided) { + // fill -inf -> max in group, 0 -> exp_sum to make compensation work + { + // current max -> mem + vector max_m = -60000; + // kq_max_wg: [b, hq, N/BLOCK_WG_N, M_aligned] + uint offset = (id_wg_n * M_aligned + id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * sizeof(half); + cm_ptr_store((int*)kq_max_wg, offset, max_m.format()); + } + { + // store + matrix sum_t = 0; + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, M - 1, (uint)(K_block_pad * sizeof(half) - 1), (uint)(K_block_pad * sizeof(half) - 1), + (int)((id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) / block_size_div_stride), (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) }; + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + } + + return; + } +#endif + // assume block index coming from 0 in block_indices_begins + int block_index_begin = ((int*)block_indices_begins)[0]; + int* block_indices_p = (int*)block_indices + block_index_begin; + int b_adjacent_between_head = query_stride / STRIDE; + // M[0:16*2]xK[0:16] + lsc::block_2d_desc desc_a{ query, M - 1, (uint)((query_stride - hq * HEAD_SIZE) * sizeof(half) - 1), (uint)(query_stride * sizeof(half) - 1), + (STRIDE - 1) * b_adjacent_between_head, (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) }; + // prefetch A + static constexpr int SG_MN = SG_M * SG_N; + lsc::block_2d_desc desc_prefetch_a{query, M - 1, (uint)((query_stride - hq * HEAD_SIZE) * sizeof(half) - 1), (uint)(query_stride * sizeof(half) - 1), + (STRIDE - 1) * b_adjacent_between_head, (int)(id_wg_m * BLOCK_WG_M + id_sg_mn * (BLOCK_WG_M / SG_MN)) }; + // M[0:16*2]xK[0:16] --> 8+8 regs + matrix a0; + + // M[0:16]xK[0:32] + uint block_idx = (uint)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * STRIDE / KV_BLOCK_SIZE; +#if USE_INT8 + uint offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); + lsc::block_2d_desc desc_b0{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(char) - 1), (uint)(K * sizeof(char) - 1), + 0, 0 }; + uint scale_offset0 = offset + KV_BLOCK_SIZE * HEAD_SIZE; + offset = block_indices_p[block_idx + 1] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); + lsc::block_2d_desc desc_b1{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(char) - 1), (uint)(K * sizeof(char) - 1), + 0, 0 }; + uint scale_offset1 = offset + KV_BLOCK_SIZE * HEAD_SIZE; + // prefetch B + block_idx = (uint)(id_wg_n * BLOCK_WG_N + id_sg_mn * (BLOCK_WG_N / SG_MN)) * STRIDE / KV_BLOCK_SIZE; + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); + static_assert(BLOCK_WG_N / SG_MN <= KEY_LINES_PER_LOAD, "prefetch lines should be inside one block"); + lsc::block_2d_desc desc_prefetch_b{ key_cache + offset, BLOCK_WG_N / SG_MN - 1, (uint)(K * sizeof(char) - 1), (uint)(K * sizeof(char) - 1), + 0, 0 }; + + // N[:]xK[0:32] --> 16 * 1 regs + matrix b0_up_s8, b0_down_s8, b1_up_s8, b1_down_s8; + matrix b0; // --> 16 regs + matrix scales, zps; + matrix scales_block, zps_block; +#else + uint offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_b0{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + offset = block_indices_p[block_idx + 1] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_b1{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // prefetch B + block_idx = (uint)(id_wg_n * BLOCK_WG_N + id_sg_mn * (BLOCK_WG_N / SG_MN)) * STRIDE / KV_BLOCK_SIZE; + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + static_assert(BLOCK_WG_N / SG_MN <= KEY_LINES_PER_LOAD, "prefetch lines should be inside one block"); + lsc::block_2d_desc desc_prefetch_b{ key_cache + offset, BLOCK_WG_N / SG_MN - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // 0~2 M[:]xK[0:16] 2~4 K[16:32] --> 32 * 2 regs + matrix b0, b1; +#endif + + // warmup + // prefetch + cm_prefetch(desc_prefetch_b); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + cm_prefetch(desc_prefetch_a); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load b: N[0:16]xK[0:16] +#if USE_INT8 + scales_block[0].format() = cm_ptr_load((uint64_t*)key_cache, scale_offset0); + zps_block[0].format() = cm_ptr_load((uint64_t*)key_cache, scale_offset0 + KV_BLOCK_SIZE * (uint)sizeof(half)); + scales_block[1].format() = cm_ptr_load((uint64_t*)key_cache, scale_offset1); + zps_block[1].format() = cm_ptr_load((uint64_t*)key_cache, scale_offset1 + KV_BLOCK_SIZE * (uint)sizeof(half)); + + cm_load(b0_up_s8.format(), desc_b0); + cm_load(b0_down_s8.format(), desc_b1); +#else + cm_load(b0[0].format(), desc_b0); + cm_load(b0[1].format(), desc_b1); +#endif + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); + + cm_sbarrier(1); + +#if USE_INT8 + auto dec = [&](vector B0_i8, vector B1_i8, matrix_ref B0) { +#pragma unroll + for (int n = 0; n < REG_N; n++) { + auto b = B0[n].format(); +#pragma unroll + for (int m = 0; m < 8; m++) { + auto b_row = b[m]; + vector d0; + if (n == 0) + d0 = B0_i8.format()[m / 2].select<16, 2>(m % 2); + else + d0 = B1_i8.format()[m / 2].select<16, 2>(m % 2); + b_row.format() = d0.format(); + b_row *= half{32768.0}; + b_row *= half{512.0}; + b_row = (b_row - zps[n]) * scales[n]; + } + } + }; +#endif + auto dot = [&](matrix A, matrix B) { +#pragma unroll + for (int reg_n = 0; reg_n < REG_N; reg_n++) { +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M; reg_m++) { + acc.row((ushort)(reg_m * REG_N + reg_n)) = cm_dpas(acc.row((ushort)(reg_m * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A.row((ushort)reg_m).format()); + } + } + }; + + for (uint s = 0; s < STRIDE; s++) { +#if USE_INT8 + auto tmp = scales_block[0].select<16, 16>(s); + scales[0].select<16, 2>(0) = tmp; + scales[0].select<16, 2>(1) = scales[0].select<16, 2>(0); + tmp = scales_block[1].select<16, 16>(s); + scales[1].select<16, 2>(0) = tmp; + scales[1].select<16, 2>(1) = scales[1].select<16, 2>(0); + tmp = zps_block[0].select<16, 16>(s); + zps[0].select<16, 2>(0) = tmp; + zps[0].select<16, 2>(1) = zps[0].select<16, 2>(0); + tmp = zps_block[1].select<16, 16>(s); + zps[1].select<16, 2>(0) = tmp; + zps[1].select<16, 2>(1) = zps[1].select<16, 2>(0); +#endif + #pragma unroll + for (uint hs = 0; hs < HEAD_SIZE / BLOCK_WG_K; hs++) { + // prefetch + cm_prefetch(desc_prefetch_b); + cm_prefetch(desc_prefetch_a); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + if (hs == HEAD_SIZE / BLOCK_WG_K - 1) + desc_prefetch_a.set_block_x((STRIDE - 1 - s - 1) * b_adjacent_between_head); + else + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load a: M[0:16*4]xK[0:16] + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(0).format(), desc_a); + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(4).format(), desc_a); + // load b: N[0:16*2]xK[16:32] +#if USE_INT8 + cm_load(b1_up_s8.format(), desc_b0); + cm_load(b1_down_s8.format(), desc_b1); + dec(b0_up_s8.format().select<64, 1>(), b0_down_s8.format().select<64, 1>(), b0); + dot(a0, b0); +#else + cm_load(b1[0].format(), desc_b0); + cm_load(b1[1].format(), desc_b1); + dot(a0, b0); +#endif + + desc_a.set_block_x(desc_a.get_block_x() + 16); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); + + // load a: M[0:16*4]xK[16:32] + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(0).format(), desc_a); + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(4).format(), desc_a); + +#if USE_INT8 + dec(b0_up_s8.format().select<64, 1>(64), b0_down_s8.format().select<64, 1>(64), b0); + dot(a0, b0); +#else + cm_load(b0[0].format(), desc_b0); + cm_load(b0[1].format(), desc_b1); + dot(a0, b1); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); +#endif + desc_a.set_block_x(desc_a.get_block_x() + 16); + + // prefetch + cm_prefetch(desc_prefetch_b); + cm_prefetch(desc_prefetch_a); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load a: M[0:16*4]xK[32:48] + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(0).format(), desc_a); + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(4).format(), desc_a); + + // load b: N[0:16*2]xK[32:64] +#if USE_INT8 + cm_load(b0_up_s8.format(), desc_b0); + cm_load(b0_down_s8.format(), desc_b1); + dec(b1_up_s8.format().select<64, 1>(), b1_down_s8.format().select<64, 1>(), b0); + dot(a0, b0); +#else + cm_load(b1[0].format(), desc_b0); + cm_load(b1[1].format(), desc_b1); + dot(a0, b0); +#endif + + desc_a.set_block_x(desc_a.get_block_x() + 16); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); + + // load a: M[0:16*4]xK[48:64] + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(0).format(), desc_a); + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(4).format(), desc_a); + if (hs == HEAD_SIZE / BLOCK_WG_K - 1) { + desc_a.set_block_x((STRIDE - 1 - s - 1) * b_adjacent_between_head); + } else { + desc_a.set_block_x(desc_a.get_block_x() + 16); + } + +#if USE_INT8 + dec(b1_up_s8.format().select<64, 1>(64), b1_down_s8.format().select<64, 1>(64), b0); + dot(a0, b0); +#else + cm_load(b0[0].format(), desc_b0); + cm_load(b0[1].format(), desc_b1); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); + dot(a0, b1); +#endif + + cm_sbarrier(0); + cm_sbarrier(1); + } + } + + cm_sbarrier(0); + + matrix acc_half; +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M; reg_m++) { +#pragma unroll + for (int reg_n = 0; reg_n < REG_N; reg_n++) { + acc_half.select(reg_m * BLOCK_REG_M, reg_n * BLOCK_REG_N) = + acc.row(reg_m * REG_N + reg_n) * float{INV_S}; + } + } + + // if M(aka query) has tails, the following will not change the accuracy: + // gemm will compute results for the padding M(all should be zeros), the kq_max/kq_max_wg/kq_exp_partial_sum are along the query dimension and + // the results can be dropped in the future stage. To simplify the logic, the size of kq_max/kq_max_wg/kq_exp_partial_sum must be enough to hold + // all tails + padding results. + int n_start = (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N); + n_start = MYMIN(n_start, N); + int n_end = MYMIN(n_start + BLOCK_SG_N, N); + int valid_n = n_end - n_start; + matrix sum_t; + vector seq_m; + cmtl::cm_vector_assign(seq_m.select_all(), 0, 1); + vector_ref seq = seq_m.select(); + vector n_pos = (uint)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) + seq; +#if IS_CAUSAL == 1 + bool skip_mask = false; + int m_start = (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M + q_start_strided); + // in streaming scenario, the past kvcache length may be arbitrary so valid causal mask of a workgroup may start at arbitrary position + if (n_end <= m_start) { + // all are inside causal mask == 1 + skip_mask = true; + } else { + #pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + SIMD_IF_BEGIN (n_pos > m_start + reg_m) { + acc_half.row(reg_m) = half{-60000}; + } SIMD_IF_END; + } + } +#else + bool skip_mask = true; +#endif + vector max_m; + if (valid_n != BLOCK_SG_N) { +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + acc_half.row(reg_m).merge(half{-60000}, n_pos >= N); + } + } + max_m.select<32, 1>() = reduce2d<1, 0, 1>(acc_half.select<32, 1, 32, 1>()).format(); + max_m.select<32, 1>(32) = reduce2d<1, 0, 1>(acc_half.select<32, 1, 32, 1>(32)).format(); + + { + uint slm_offset = (id_sg_n * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * (uint)sizeof(half); + // current max -> slm + cm_slm_block_write(slm, slm_offset, max_m.format()); + cm_slm_fence(CM_LOCAL_BARRIER); + cm_barrier(); + // max inside wg + cm_slm_block_read(slm, id_sg_m * BLOCK_SG_M * (uint)sizeof(half), max_m.format()); + vector tmp; +#pragma unroll + for (uint i = 1; i < SG_N; i++) { + slm_offset = (i * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * (uint)sizeof(half); + cm_slm_block_read(slm, slm_offset, tmp.format()); + max_m = cm_max(max_m, tmp); + } + // max across wg + // kq_max: [b, hq, M_aligned] + //auto max_offsets = (id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M + seq_m) * (uint)sizeof(half); + //cm_ptr_atomic((half*)kq_max, max_offsets, max_m); + + // current max -> mem + // kq_max_wg: [b, hq, N/BLOCK_WG_N, M_aligned] + uint offset = (id_wg_n * M_aligned + id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * sizeof(half); + cm_ptr_store((int*)kq_max_wg, offset, max_m.format()); + } + { + // kq_exp_partial_sum: [b, hq, M_aligned, N/(BLOCK_SIZE/STRIDE)] + if (valid_n == BLOCK_SG_N && skip_mask) { +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + acc_half.row(reg_m) = cm_exp((acc_half.row(reg_m) - max_m[reg_m]) * log2e); + } + } else { +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + acc_half.row(reg_m) = cm_exp((acc_half.row(reg_m) - max_m[reg_m]) * log2e); + // causal mask in the following case: + // block0(EU0) block1(EU1) + // 1 1 1 1 0 0 0 0 + // 1 1 1 1 1 0 0 0 + // 1 1 1 1 1 1 0 0 + // 1 1 1 1 1 1 1 0 + // the acc value of first row of block1 should be -inf --> max(first column) == -inf --> exp(first column - max) == 1, this is incorrect + // so need to use simd_if to detect per element state +#if IS_CAUSAL + SIMD_IF_BEGIN ((n_pos > m_start + reg_m) | (n_pos >= N)) { +#else + SIMD_IF_BEGIN (n_pos >= N) { +#endif + acc_half.row(reg_m) = 0; + } SIMD_IF_END; + } + } + sum_t.select<32, 1, 4, 1>( 0).format() = reduce2d<4, 1, 4>(acc_half.select<32, 1, 32, 1>( 0)).format(); + sum_t.select<32, 1, 4, 1>(32).format() = reduce2d<4, 1, 4>(acc_half.select<32, 1, 32, 1>(32)).format(); + } + // store + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, M - 1, (uint)(K_block_pad * sizeof(half) - 1), (uint)(K_block_pad * sizeof(half) - 1), + (int)((id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) / block_size_div_stride), (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) }; + cm_store(desc_c, sum_t.select<8, 1, 4, 1>( 0).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>( 8).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(16).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(24).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(32).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(40).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(48).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(56).format()); +} +#endif diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp new file mode 100644 index 00000000000000..b26e97b2478b1d --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2020-2023, Intel Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +#include +#include + +#include "sort.hpp" + + +// kq_max_wg: [b, hq, n_groups, q_stride_pad] +// kq_exp_partial_sum: [b, hq, q_stride_pad, k_block_pad] +// kq_sum: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] +CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_partial_sum, svmptr_t block_mask, uint q_stride, uint q_stride_pad, uint k_block_pad, float thresh, uint causal_start_index +#if DEBUG_ACC == 1 + , svmptr_t kq_sum +#endif +) { + constexpr int SG_SIZE = 16; +#ifndef BLOCK_SG_M + #define BLOCK_SG_M 64 + #define BLOCK_SG_N 32 + #define SG_M 2 + #define SG_N 4 + #define HEAD_SIZE 128 + #define KV_BLOCK_SIZE 256 + #define STRIDE 16 + #define BLOCK_SIZE 128 + #define BLOCK_SHARE_MAX 256 +#endif + + const int TOKEN_IN_BLOCK = (BLOCK_SIZE / STRIDE); + int m = m_block * TOKEN_IN_BLOCK; + vector max_m; + + const int TOKEN_SHARE_MAX = BLOCK_SHARE_MAX / TOKEN_IN_BLOCK; + kq_exp_partial_sum += m * k_block_pad * (int)sizeof(half); + kq_max_wg += m * (int)sizeof(half); + constexpr half log2e = 1.4426950408889634f; + matrix sum_m = 0; + matrix data; + int m_start = MYMIN(m, q_stride); + int m_end = MYMIN(m_start + TOKEN_SHARE_MAX, q_stride); + int valid_m = m_end - m_start; + if (valid_m == 0) return; + lsc::block_2d_desc desc_sum{ kq_exp_partial_sum, (uint)valid_m - 1, (uint)(k_block_pad * sizeof(half) - 1), (uint)(k_block_pad * sizeof(half) - 1), + 0, 0 }; + { + // find max: (k_block_pad / TOKEN_SHARE_MAX) * q_stride_pad + max_m = half{-60000}; + + for (int idx = 0; idx < k_block_pad / TOKEN_SHARE_MAX; idx++) { + vector max_m_in_group; + max_m_in_group.format() = cm_ptr_load((int*)kq_max_wg, q_stride_pad * idx * (int)sizeof(half)); + max_m = cm_max(max_m, max_m_in_group); + } + } + // compensation: val*exp(local - global) + desc_sum.set_block_x(0); + for (int j = 0, idx = 0; j < k_block_pad; j += TOKEN_SHARE_MAX, idx++) { + vector max_m_in_group; + max_m_in_group.format() = cm_ptr_load((int*)kq_max_wg, q_stride_pad * idx * (int)sizeof(half)); + cm_load(data.format(), desc_sum); + for (int i = 0; i < TOKEN_IN_BLOCK; i++) { + if (i < valid_m) { + data.row(i) *= cm_exp((max_m_in_group[i] - max_m[i]) * log2e); + sum_m.row(i) += data.row(i); + } + } + cm_store(desc_sum, data.format()); + desc_sum.set_block_x(desc_sum.get_block_x() + TOKEN_SHARE_MAX); + } + + // exp/sum + vector inv_sum_v; + for (int i = 0; i < TOKEN_IN_BLOCK; i++) { + if (i < valid_m) + inv_sum_v[i] = 1.0f / cm_sum(sum_m.row(i)); + else + inv_sum_v[i] = 0; + } + // compensation: sum(val*inv_sum_v) + vector sum_m_after_add = 0; + desc_sum.set_block_x(0); +#if DEBUG_ACC == 1 + kq_sum += m_block * k_block_pad * (int)sizeof(half); +#endif + for (int j = 0; j < k_block_pad; j += TOKEN_SHARE_MAX) { + cm_load(data.format(), desc_sum); + data.row(0) *= inv_sum_v[0]; + for (int i = 1; i < TOKEN_IN_BLOCK; i++) { + data.row(0) += data.row(i) * inv_sum_v[i]; + } + desc_sum.set_block_x(desc_sum.get_block_x() + TOKEN_SHARE_MAX); + sum_m_after_add += data.row(0); + cm_ptr_store((int*)kq_exp_partial_sum, j * (int)sizeof(half), data.row(0).format()); +#if DEBUG_ACC == 1 + cm_ptr_store((int*)kq_sum, j * (int)sizeof(half), data.row(0).format()); +#endif + } + auto thresh_act = cm_sum(sum_m_after_add) * thresh; + + // content of 8(aka stride) lines: + // line 0: score + // line 1: sorted value + // line 3: sorted index + // line 5: sorted tmp + // line 6: accumalative score + block_mask += m_block * k_block_pad; + auto score = kq_exp_partial_sum + 0 * k_block_pad * (int)sizeof(half); + auto sorted_value = kq_exp_partial_sum + 1 * k_block_pad * (int)sizeof(half); + auto sorted_index = kq_exp_partial_sum + 3 * k_block_pad * (int)sizeof(half); + auto sorted_tmp = kq_exp_partial_sum + 5 * k_block_pad * (int)sizeof(half); + auto acc_score = kq_exp_partial_sum + 6 * k_block_pad * (int)sizeof(half); + +#if IS_CAUSAL == 1 + auto score_p = (half*)score; + half s_0 = score_p[0]; + half s_causal = score_p[causal_start_index + m_block]; + half s_sum = s_0; + if (causal_start_index + m_block) s_sum += s_causal; + score_p[0] = -1; + score_p[causal_start_index + m_block] = -1; + sort(slm, score, sorted_value + 2 * sizeof(half), sorted_index + 2 * sizeof(half), sorted_tmp, k_block_pad); + uchar* block_mask_p = (uchar*)block_mask; + auto sorted_value_p = (half*)sorted_value; + auto sorted_index_p = (ushort*)sorted_index; + auto acc_score_p = (half*)acc_score; + sorted_value_p[0] = 0; + sorted_value_p[1] = s_sum; + block_mask_p[0] = 1; + block_mask_p[causal_start_index + m_block] = 1; + float sum_cur = s_sum; +#if DEBUG_ACC == 1 + acc_score_p[0] = 0; + acc_score_p[1] = 0; +#endif + for (int j = 2; j < k_block_pad - 2; j++) { +#if DEBUG_ACC == 1 + acc_score_p[j] = sum_cur; +#endif + if (sum_cur < thresh_act) { + block_mask_p[sorted_index_p[j]] = 1; + } else { +#if DEBUG_ACC != 1 + break; +#endif + } + sum_cur += sorted_value_p[j]; + } + +#else + + sort(slm, score, sorted_value, sorted_index, sorted_tmp, k_block_pad); + uchar* block_mask_p = (uchar*)block_mask; + auto sorted_value_p = (half*)sorted_value; + auto sorted_index_p = (ushort*)sorted_index; + auto acc_score_p = (half*)acc_score; + block_mask_p[0] = 1; + float sum_cur = 0; +#if DEBUG_ACC == 1 + acc_score_p[0] = 0; +#endif + for (int j = 0; j < k_block_pad - 1; j++) { + sum_cur += sorted_value_p[j]; +#if DEBUG_ACC == 1 + acc_score_p[j + 1] = sum_cur; +#endif + if (sum_cur < thresh_act) { + block_mask_p[sorted_index_p[j]] = 1; + } else { + block_mask_p[sorted_index_p[j]] = 1; +#if DEBUG_ACC != 1 + break; +#endif + } + } +#endif +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp new file mode 100644 index 00000000000000..f05396fac810d9 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2020-2023, Intel Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +#include +#include + +#ifndef ATTR +#define ATTR [[type("svmptr_t")]] +#define ATTR_BUF [[type("buffer_t")]] +#endif + +#if 0 +template +void show(const vector mat, bool is_hex=true) { + printf("vector [%d]:\n[", N); + for(int n = 0; n < N; n ++) { + if (is_hex) + printf("%x,", mat[n]); + else + printf("%d,", mat[n]); + } + printf("]\n"); +} + +template +void show(const matrix mat) { + printf("Matrix [%d, %d]:\n", M, N); + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[m][n]); + } + printf("],\n"); + } + printf("]\n"); +} +#endif + +// https://gpuopen.com/download/Introduction_to_GPU_Radix_Sort.pdf +template +CM_INLINE void sort(uint slm, svmptr_t src, svmptr_t sorted_value, svmptr_t sorted_index, svmptr_t sort_tmp, uint n) { + const ushort THREADS = 32; + vector seq_u32; + vector seq; + cmtl::cm_vector_assign(seq.select_all(), 0, 1); + cmtl::cm_vector_assign(seq_u32.select_all(), 0, 1); + svmptr_t sorted_src = src; + svmptr_t sorted_tmp = sorted_value; + svmptr_t cur_idx = sorted_index; + svmptr_t cur_idx_tmp = sort_tmp; + uint iter = (n + THREADS - 1) / THREADS; + vector offset_src; + cmtl::cm_vector_assign(offset_src.select_all(), 0, iter); + auto f16_u16 = [] (vector_ref in, vector_ref out) { + static const ushort HIGH_BIT = 1 << 15; + auto in_u16 = in.format(); + auto mask = (in_u16 & HIGH_BIT) != 0; + vector m; + m.merge(ushort{0xffff}, HIGH_BIT, mask); + out = in_u16 ^ m; + }; + auto u16_f16 = [] (vector_ref in, vector_ref out) { + static const ushort HIGH_BIT = 1 << 15; + auto mask = (in & HIGH_BIT) != 0; + vector m; + m.merge(HIGH_BIT, ushort{0xffff}, mask); + out.format() = in ^ m; + }; + if constexpr(std::is_same::value) { + { + // f16 to u16 + vector data; + vector data_f; + int i; + for (i = 0; i + THREADS <= n / THREADS * THREADS; i += THREADS) { + data_f.format() = cm_ptr_load((int*)src, i * (int)sizeof(short)); + f16_u16(data_f, data); + cm_ptr_store((int*)src, i * (int)sizeof(short), data.format()); + } + if (i < n) { + auto pos = seq_u32 + i; + SIMD_IF_BEGIN (pos < n) { + data_f = cm_ptr_load((half*)src, pos * (uint)sizeof(half)); + f16_u16(data_f, data); + cm_ptr_store((ushort*)src, pos * (uint)sizeof(ushort), data); + } SIMD_IF_END; + } + } + } + { + // generate idx + vector data; + int i; + for (i = 0; i + THREADS <= n / THREADS * THREADS; i += THREADS) { + data = seq + i; + cm_ptr_store((int*)cur_idx, i * (int)sizeof(short), data.format()); + } + if (i < n) { + auto pos = seq_u32 + i; + data = seq + i; + + SIMD_IF_BEGIN (pos < n) { + cm_ptr_store((ushort*)cur_idx, pos * (uint)sizeof(ushort), data); + } SIMD_IF_END; + } + } + // 4bit per pass, 4 pass for f16 + for (int pass = 0; pass < 4; pass++) { + { + // slm layout: short [16][work items] = 16*32*2 bytes + vector data = 0; + for (int i = 0; i < 16 * THREADS * sizeof(ushort); i += 256) + cm_slm_block_write(slm, i, data); + } + { + // counting phase + vector data; + for (int i = 0; i < iter; i++) { + data = cm_ptr_load((ushort*)sorted_src, (offset_src + i) * (uint)sizeof(ushort), (offset_src + i) < n); + vector bits = 0xf - ((data >> (pass * 4)) & 0xf); + vector addr = bits * THREADS + seq; + vector total; + cm_slm_read(slm, addr, total); + total += 1; + cm_slm_write(slm, addr, total); + } + } + // { + // // prefix sum + // vector data; + // cm_slm_block_read(slm, 0, data); + // for (int i = 1; i < 16 * THREADS; i++) { + // data[i] += data[i - 1]; + // } + // data.select<16 * THREADS - 1, 1>(1) = data.select<16 * THREADS - 1, 1>(0); + // data[0] = 0; + // cm_slm_block_write(slm, 0, data); + // } + { + // prefix sum + vector local_prefix = 0; + + vector seq_prefix; + cmtl::cm_vector_assign(seq_prefix.select_all(), 0, THREADS); + + #pragma unroll + for (ushort i = 0; i < THREADS; i++) { + auto prev = local_prefix; + vector hist; + vector addr = seq_prefix + i; + cm_slm_read(slm, addr, hist); + local_prefix += hist; + cm_slm_write(slm, addr, prev); + } + // Hillis-Steele scan + vector local_tmp; + local_tmp.select<15, 1>(1) = local_prefix.select<15, 1>(1) + local_prefix.select<15, 1>(0); local_tmp[0] = local_prefix[0]; + local_prefix.select<14, 1>(2) = local_tmp.select<14, 1>(2) + local_tmp.select<14, 1>(0); local_prefix.select<2, 1>(0) = local_tmp.select<2, 1>(0); + local_tmp.select<12, 1>(4) = local_prefix.select<12, 1>(4) + local_prefix.select<12, 1>(0); local_tmp.select<4, 1>(0) = local_prefix.select<4, 1>(0); + local_prefix.select<8, 1>(8) = local_tmp.select<8, 1>(8) + local_tmp.select<8, 1>(0); local_prefix.select<8, 1>(0) = local_tmp.select<8, 1>(0); + vector data; + cm_slm_block_read(slm, 0, data); + #pragma unroll + for (int i = 1; i < 16; i++) { + data.select(i * THREADS) += local_prefix[i - 1]; + } + cm_slm_block_write(slm, 0, data); + } + { + // reorder + vector data; + for (int i = 0; i < iter; i++) { + data = cm_ptr_load((ushort*)sorted_src, (offset_src + i) * (uint)sizeof(ushort), (offset_src + i) < n); + vector bits = 0xf - ((data >> (pass * 4)) & 0xf); + vector addr = bits * THREADS + seq; + vector index; + cm_slm_read(slm, addr, index); + vector offset_i32 = index * (uint)sizeof(ushort); + cm_ptr_store((ushort*)sorted_tmp, offset_i32, data, index < n); + + data = cm_ptr_load((ushort*)cur_idx, (offset_src + i) * (uint)sizeof(ushort)); + cm_ptr_store((ushort*)cur_idx_tmp, offset_i32, data, (offset_src + i) < n); + + index += 1; + cm_slm_write(slm, addr, index); + } + auto tmp = sorted_src; + sorted_src = sorted_tmp; + sorted_tmp = tmp; + tmp = cur_idx; + cur_idx = cur_idx_tmp; + cur_idx_tmp = tmp; + } + } + { + // copy to output + vector data; + vector data_f; + int i; + for (i = 0; i + THREADS <= n / THREADS * THREADS; i += THREADS) { + data.format() = cm_ptr_load((int*)src, i * (int)sizeof(short)); + if constexpr(std::is_same::value) { + u16_f16(data, data_f); + cm_ptr_store((int*)sorted_value, i * (int)sizeof(short), data_f.format()); + } else { + cm_ptr_store((int*)sorted_value, i * (int)sizeof(short), data.format()); + } + } + if (i < n) { + auto pos = seq_u32 + i; + if constexpr(std::is_same::value) { + SIMD_IF_BEGIN (pos < n) { + data = cm_ptr_load((ushort*)src, pos * (uint)sizeof(half)); + u16_f16(data, data_f); + cm_ptr_store((half*)sorted_value, pos * (uint)sizeof(half), data_f); + } SIMD_IF_END; + } else { + SIMD_IF_BEGIN (pos < n) { + data = cm_ptr_load((ushort*)src, pos * (uint)sizeof(ushort)); + cm_ptr_store((ushort*)sorted_value, pos * (uint)sizeof(ushort), data); + } SIMD_IF_END; + } + } + } +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index c51421721048bd..2452b8673ef78a 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -135,6 +135,32 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { } else { internal_buffers.emplace_back(16, indexes_dt); // 0: intermediate partition output internal_buffers.emplace_back(16, indexes_dt); // 1: softmax exp_sums + + // internal buffer for XAttention + auto out_shape = params.output_layouts[0].get_shape(); + const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; + const size_t q_len = out_shape[0]; + const uint M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` + const uint N = kv_len / STRIDE; + const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); + const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + + auto count_kq_max_wg = static_cast(desc->heads_num * N_kq_groups * q_stride_pad); + internal_buffers.emplace_back(count_kq_max_wg, ov::element::f16); // 2: kq_max_wg + + const size_t block_size = get_xattn_block_size(); + if (block_size > 1) { + OPENVINO_ASSERT(block_size % STRIDE == 0, "sparse block_size must be devidable by stride."); + const size_t sum_per_n_token_in_block = block_size / STRIDE; // FIXME + const uint sum_per_token_in_block = block_size / STRIDE; + const uint k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; + const uint k_block_pad = k_block_in_group * N_kq_groups; + auto count_kq_exp_partial_sum = static_cast(desc->heads_num * q_stride_pad * k_block_pad); + internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f16); // 3: kq_exp_partial_sum + + auto count_elements_mask = static_cast(desc->heads_num * (q_stride_pad / sum_per_n_token_in_block) * k_block_pad); + internal_buffers.emplace_back(count_elements_mask, ov::element::boolean); // 4: sparse_block_mask + } } return internal_buffers; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 4d508ac3064d88..7ab603dc737fa3 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -375,7 +375,7 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block_indices_begins args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence_begins #if PA_SPARSE_BLOCK_SIZE > 1 - args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SPARSE_BLOCK_MASK}); // sparse_block_mask + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // sparse_block_mask #endif args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); @@ -607,4 +607,245 @@ DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_da }}; } + +//----------------------------------------------------------------------------------------------------------------- +// Helpers of XAttention +//----------------------------------------------------------------------------------------------------------------- + + +//----------------------------------------------------------------------------------------------------------------- +// Base generator of XAttention +//----------------------------------------------------------------------------------------------------------------- +JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_impl_params& params) const { + auto jit = KernelGenerator::get_jit_constants(params); + jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); + + auto desc = params.typed_desc(); + + const float scale_factor = 1.0 / std::sqrt(static_cast(desc->k_head_size)) / STRIDE; + + jit.make("STRIDE", STRIDE); + jit.make("HQ", desc->heads_num); + jit.make("HK", desc->kv_heads_num); + jit.make("HEAD_SIZE", desc->k_head_size); + jit.make("SG_M", SG_M); + jit.make("SG_N", SG_N); + jit.make("BLOCK_SG_M", BLOCK_SG_M); + jit.make("BLOCK_SG_N", BLOCK_SG_N); + jit.make("BLOCK_SIZE", get_xattn_block_size()); + jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + jit.add(make_jit_constant("SCALE_FACTOR", scale_factor)); + jit.make("BLOCK_SHARE_MAX", BLOCK_WG_N); + jit.make("USE_KQ", 1); + jit.make("IS_CAUSAL", 1); + jit.make("USE_INT8", 0); + jit.make("HEAD_SIZE_KEY", desc->k_head_size); + + return jit; +} + +//----------------------------------------------------------------------------------------------------------------- +// XAttention Estimate gemm_qk generator +//----------------------------------------------------------------------------------------------------------------- +Arguments XAttentionEstimateGEMMQK::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // keys cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::QUERY}); // queries + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block indices + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins + + // outputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); // kq_max_wg + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); // kq_exp_partial_sum + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // M + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // N + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // K + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // query_pitch + args.push_back({ArgumentDescriptor::Types::SCALAR, 4}); // slice_no + args.push_back({ArgumentDescriptor::Types::SCALAR, 5}); // slice + args.push_back({ArgumentDescriptor::Types::SCALAR, 6}); // q_start_strided + + return args; +} + +DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { + return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + const auto desc = params.typed_desc(); + + // XAttention estimate is following afer kvcache_update. + const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; + // const size_t kv_heads_num = desc->kv_heads_num; + const size_t heads_num = desc->heads_num; + const size_t head_size = desc->k_head_size; + + auto querry_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; + auto key_layout = params.input_layouts[PagedAttentionInputIdx::KEY]; + + if (DEBUG_ENABLED) { // Debug + std::cout << "XAttentionEstimateGEMMQK::get_dispatch_data_func: " + << "key_layout: " << key_layout.to_string() << ", querry_layout: " << querry_layout.to_string() << std::endl; + std::cout << "\tkey_dims = ["; + for (auto& it : key_layout.get_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\tkey_pads = ["; + for (auto& it : key_layout.get_padded_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\tquery_dims = ["; + for (auto& it : querry_layout.get_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + std::cout << "\ttquery_pads = ["; + for (auto& it : querry_layout.get_padded_dims()) { + std::cout << static_cast(it) << ", "; + } + std::cout << "]" << std::endl; + } + + auto out_shape = params.output_layouts[0].get_shape(); + const size_t q_len = out_shape[0]; + + const uint M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` + const uint N = kv_len / STRIDE; + const uint K = STRIDE * head_size; + auto get_simple_pitch = [](const layout& layout) { + size_t pitch = 1; + auto dims_padding = layout.get_padded_dims(); + for(size_t i = dims_padding.size() - 1; i > 0; --i) { + pitch = dims_padding[i]; + if(pitch > 1) { + break; + } + } + return pitch; + }; + const uint query_pitch = get_simple_pitch(querry_layout); + const uint slice_no = 0, slice = 0; + + const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); + const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + + auto& wgs = kd.params.workGroups; + wgs.global = {N_kq_groups * (q_stride_pad / BLOCK_WG_M) * SG_N, SG_M, heads_num}; + wgs.local = {SG_N, SG_M, 1}; + + const uint q_start_strided = N - M; + OPENVINO_ASSERT(N > M, "length of key cache must be greater or equal than query"); + + auto& scalars = kd.params.scalars; + std::vector scaler_value = {M, N, K, query_pitch, slice_no, slice, q_start_strided}; + scalars.resize(scaler_value.size()); + + if (DEBUG_ENABLED) { // Debug + size_t kv_len = get_kv_len(params, PagedAttentionStage::PREFILL); + size_t max_context_len = get_max_context_len(params); + size_t past_len = get_past_len(params, 0); + std::cout << "XAttentionEstimateGEMMQK::get_dispatch_data_func: " + << "N_kq_groups: " << N_kq_groups << ", q_stride_pad: " << q_stride_pad << ", q_start_strided: " << q_start_strided << ", kv_len: " << kv_len + << ", max_context_len = " << max_context_len << ", past_len = " << past_len << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] + << ", " << wgs.global[2] << "]" + << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + } + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::UINT32; + scalars[i].v.u32 = static_cast(scaler_value[i]); + } + }}; +} + +//----------------------------------------------------------------------------------------------------------------- +// XAttention Estimate find_block generator +//----------------------------------------------------------------------------------------------------------------- +Arguments XAttentionEstimateFindBlock::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + + // inputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); // kq_max_wg + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); // kq_exp_partial_sum + + // outputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // block_mask + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_stride + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_stride_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // k_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // causal_start_index + args.push_back({ArgumentDescriptor::Types::SCALAR, 4}); // thresh + + return args; +} + +DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { + return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + + const auto desc = params.typed_desc(); + // auto rtp = static_cast(rt_params); + + assert(rt_params != nullptr); + + const uint wg_k = BLOCK_WG_M; + const uint wg_q = BLOCK_WG_N; + const size_t block_size = get_xattn_block_size(); + OPENVINO_ASSERT(wg_k % block_size == 0, "wg_k should be multiple of block_size then there is no tails from block_size"); + OPENVINO_ASSERT(wg_q % block_size == 0, "wg_q should be multiple of block_size then there is no tails from block_size"); + + const size_t sum_per_n_token_in_block = block_size / STRIDE; + + const size_t batch = params.input_layouts[PagedAttentionInputIdx::QUERY].get_partial_shape()[0].get_length(); + const size_t heads_num = desc->heads_num; + // const size_t head_size = desc->k_head_size; + + auto out_shape = params.output_layouts[0].get_shape(); + const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; + const size_t q_len = out_shape[0]; + const uint M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` + const uint N = kv_len / STRIDE; + const uint q_stride = M; + const uint k_stride = N; + const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); + const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + + const uint sum_per_token_in_block = block_size / STRIDE; + const uint k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; + const uint k_block_pad = k_block_in_group * N_kq_groups; + + const uint q_block = ceil_div(q_stride, sum_per_n_token_in_block); + const uint k_block = ceil_div(k_stride, sum_per_n_token_in_block); + + wgs.global = {q_stride_pad / sum_per_n_token_in_block, heads_num, batch}; + wgs.local = {1, 1, 1}; + + auto& scalars = kd.params.scalars; + std::vector scaler_value = {q_stride, q_stride_pad, k_block_pad, k_block - q_block}; + scalars.resize(scaler_value.size() + 1); + + if (DEBUG_ENABLED) { // Debug + std::cout << "XAttentionEstimateFindBlock::get_dispatch_data_func: " + << "k_block: " << k_block << ", q_block: " << q_block + << "q_stride: " << q_stride << ", q_stride_pad: " << q_stride_pad << ", k_block_pad: " << k_block_pad << ", gws: [" << wgs.global[0] << ", " + << wgs.global[1] << ", " << wgs.global[2] << "]" + << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + } + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::UINT32; + scalars[i].v.u32 = static_cast(scaler_value[i]); + } + scalars[scaler_value.size()].t = ScalarDescriptor::Types::FLOAT32; // the last is for thresh with f32 dtype + scalars[scaler_value.size()].v.f32 = static_cast(THRESH); + }}; +} + } // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index ffabc6c9cac128..cf9e359661b17f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -33,8 +33,16 @@ constexpr auto get_pa_build_options() { // BLOCK_SIZE can be 16/32/64/128/256 #define PA_KV_CACHE_BLOCK_SIZE 256 // sparse attention block size is set to 1 to disable sparse attention support in CM kernels -#define PA_SPARSE_BLOCK_SIZE 1 +#define PA_SPARSE_BLOCK_SIZE 128 +constexpr uint BLOCK_SG_M = 64; //32 +constexpr uint BLOCK_SG_N = 32; +constexpr uint SG_M = 4; +constexpr uint SG_N = 8; +constexpr uint BLOCK_WG_M = BLOCK_SG_M * SG_M; +constexpr uint BLOCK_WG_N = BLOCK_SG_N * SG_N; +constexpr int STRIDE = 16; +constexpr float THRESH = 0.9; enum class PagedAttentionStage : uint8_t { GENERATE = 0, PREFILL = 1, MIXED = 2, UNKNOWN = 3 }; struct PagedAttentionRuntimeParams : public ImplRuntimeParams { @@ -45,6 +53,10 @@ struct PagedAttentionRuntimeParams : public ImplRuntimeParams { size_t paged_attention_aligned_seq_len; }; + +//----------------------------------------------------------------------------------------------------------------- +// Helpers of XAttention +//----------------------------------------------------------------------------------------------------------------- int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, int64_t target_seq_len_block_size); PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param); size_t get_max_context_len(const kernel_impl_params& params); @@ -52,6 +64,9 @@ size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx); size_t get_partition_size(); size_t get_partition_num(const size_t kv_len); + +inline size_t get_xattn_block_size() { return PA_SPARSE_BLOCK_SIZE; } + class PagedAttentionGeneratorBase : public KernelGenerator { public: explicit PagedAttentionGeneratorBase(std::string_view kernel_name, std::string_view stage_suffix = "_cm") : KernelGenerator(kernel_name, stage_suffix) {} @@ -93,4 +108,31 @@ class PagedAttentionGeneratorSingleTokenFinalization : public PagedAttentionGene [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; }; +//----------------------------------------------------------------------------------------------------------------- +// XAttention Estimate generators +//----------------------------------------------------------------------------------------------------------------- +class XAttentionEstimateGeneratorBase : public KernelGenerator { +public: + explicit XAttentionEstimateGeneratorBase(std::string_view kernel_name, std::string_view stage_suffix = "_cm") : KernelGenerator(kernel_name, stage_suffix) {} + [[nodiscard]] std::string get_build_options(const RuntimeParams& params) const override { + return KernelGenerator::get_build_options(params) + get_pa_build_options(); + } + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; +}; +class XAttentionEstimateGEMMQK : public XAttentionEstimateGeneratorBase { +public: + XAttentionEstimateGEMMQK() : XAttentionEstimateGeneratorBase("xattn_gemm_qk") {} + // [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +class XAttentionEstimateFindBlock : public XAttentionEstimateGeneratorBase { +public: + XAttentionEstimateFindBlock() : XAttentionEstimateGeneratorBase("xattn_find_block") {} + // [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + } // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm new file mode 100644 index 00000000000000..7cecacffaf0c27 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm @@ -0,0 +1,65 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +namespace KERNEL_NAME { +#include "find_block.hpp" + +#ifndef ATTR +#define ATTR [[type("svmptr_t")]] +#define ATTR_BUF [[type("buffer_t")]] +#endif + +// _GENX_MAIN_ void find_block( +extern "C" _GENX_MAIN_ void KERNEL_NAME( + svmptr_t kq_max_wg ATTR, + svmptr_t kq_exp_partial_sum ATTR, + svmptr_t block_mask ATTR, + uint q_stride, uint q_stride_pad, uint k_block_pad, + uint causal_start_index, float thresh +#if DEBUG_ACC == 1 + , svmptr_t kq_sum ATTR +#endif +) { + // kq_max_wg: [b, hq, n_groups, q_stride_pad] + // kq_exp_partial_sum: [b, hq, q_stride_pad, k_block_pad] + // kq_sum: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] + // block_mask: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] + // [1, 32, 256], [1, 32, 64, 256], [1, 32, 256, 64 * 16], A_sum:[1, 32, 32, 64 * 16] + // global: [q_stride_pad/TOKEN_IN_BLOCK, hq, b] + const int TOKEN_IN_BLOCK = BLOCK_SIZE / STRIDE; + const int TOKEN_SHARE_MAX = BLOCK_SHARE_MAX / TOKEN_IN_BLOCK; + uint m = cm_group_id(0); + uint hq = cm_group_id(1); + uint b = cm_group_id(2); + kq_max_wg += (b * HQ + hq) * (k_block_pad / TOKEN_SHARE_MAX) * q_stride_pad * (uint)sizeof(half); + kq_exp_partial_sum += (b * HQ + hq) * q_stride_pad * k_block_pad * (uint)sizeof(half); +#if DEBUG_ACC == 1 + kq_sum += (b * HQ + hq) * q_stride_pad / TOKEN_IN_BLOCK * k_block_pad * (uint)sizeof(half); +#endif + block_mask += (b * HQ + hq) * q_stride_pad / TOKEN_IN_BLOCK * k_block_pad; + + const uint slm_size = 32 * 16 * sizeof(ushort); + cm_slm_init(slm_size); + auto slm = cm_slm_alloc(slm_size); + + find(slm, m, kq_max_wg, kq_exp_partial_sum, block_mask, q_stride, q_stride_pad, k_block_pad, thresh, causal_start_index +#if DEBUG_ACC == 1 + , kq_sum +#endif + ); +} + +} // NAMESPACE \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm new file mode 100644 index 00000000000000..b6175b73b0f5f1 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm @@ -0,0 +1,112 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +namespace KERNEL_NAME { +#include "estimate.hpp" + +#define ABS(x) (x) < 0 ? -(x) : (x) + +CM_INLINE void get_mn(uint& id_wg_m, uint& id_wg_n, uint M, uint N, int slice_no, int slice, const int BLOCK_WG_M, const int BLOCK_WG_N) { + uint id_wg_mn = cm_group_id(0); + if (slice_no == 0) { + if (slice == 0) { + // loop M first, N is shared, total = N/256*M+N + uint WG_MN = (M + BLOCK_WG_M - 1) / BLOCK_WG_M; + id_wg_m = id_wg_mn % WG_MN; + id_wg_n = id_wg_mn / WG_MN; + } else { + // loop N first, M is shared, total = M/128*N+M + uint WG_MN = (N + BLOCK_WG_N - 1) / BLOCK_WG_N; + id_wg_n = id_wg_mn % WG_MN; + id_wg_m = id_wg_mn / WG_MN; + } + } else { + uint wg_x = slice > 0 ? N / BLOCK_WG_N : M / BLOCK_WG_M; + uint slice_no_abs = ABS(slice_no); + uint slice_abs = ABS(slice); + int id_wg_mn_in_reminder = (int)id_wg_mn - (int)(slice_no_abs * slice_abs * wg_x); + uint slice_idx; + // in [slice_no x slice] + if (id_wg_mn_in_reminder < 0) { + slice_idx = id_wg_mn / (slice_abs * wg_x); + uint rem_in_slice = id_wg_mn % (slice_abs * wg_x); + uint x = rem_in_slice % slice_abs; + uint y = rem_in_slice / slice_abs; + id_wg_m = slice > 0 ? x + slice_idx * slice_abs : y; + id_wg_n = slice < 0 ? x + slice_idx * slice_abs : y; + } else { + uint slice_rem = slice_abs + (slice_no > 0 ? 1 : -1); + slice_idx = id_wg_mn_in_reminder / (slice_rem * wg_x); + uint rem_in_slice = id_wg_mn_in_reminder % (slice_rem * wg_x); + uint x = rem_in_slice % slice_rem; + uint y = rem_in_slice / slice_rem; + id_wg_m = slice > 0 ? x + slice_idx * slice_rem + slice_no_abs * slice_abs : y; + id_wg_n = slice < 0 ? x + slice_idx * slice_rem + slice_no_abs * slice_abs : y; + } + } +} + +// _GENX_MAIN_ void gemm_qk +extern "C" _GENX_MAIN_ void KERNEL_NAME( + svmptr_t key_cache ATTR, + svmptr_t query ATTR, + svmptr_t block_indices ATTR, + svmptr_t block_indices_begins ATTR, + svmptr_t kq_max_wg ATTR, + svmptr_t kq_exp_partial_sum ATTR, + uint M, uint N, uint K, uint query_stride, int slice_no, int slice, uint q_start_strided) { + const uint BLOCK_WG_M = BLOCK_SG_M * SG_M; + const uint BLOCK_WG_N = BLOCK_SG_N * SG_N; + const uint size_slm_b = 0; + uint hq = cm_group_id(2); + uint hk = hq / (HQ / HK); + const uint slm_size = SG_N * BLOCK_WG_M * sizeof(half); + cm_slm_init(slm_size); + auto slm = cm_slm_alloc(slm_size); + + static_assert(HQ % HK == 0, "HQ must be multiple of HK"); + + uint id_wg_m, id_wg_n; + get_mn(id_wg_m, id_wg_n, M, N, slice_no, slice, BLOCK_WG_M, BLOCK_WG_N); + + // key cache: [block, HQ, KV_BLOCK_SIZE, HEAD_SIZE_KEY] +#if USE_INT8 + key_cache += hk * (KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); +#else + key_cache += hk * (KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(half)); +#endif + // query: [l_q, HQ * HEAD_SIZE] + query += hq * HEAD_SIZE * (uint)sizeof(half); + + // kq_max: [hq, m_pad] + // kq_max_wg: [hq, n_groups, m_pad] + // kq_exp_partial_sum: [hq, m_pad, n_groups*BLOCK_WG_M/(BLOCK_SIZE/STRIDE)] + uint m_pad = (M + BLOCK_WG_M - 1) / BLOCK_WG_M * BLOCK_WG_M; + uint n_groups = (N + BLOCK_WG_N - 1) / BLOCK_WG_N; + kq_max_wg += hq * n_groups * m_pad * (uint)sizeof(half); + + const uint sum_per_n_token_in_block = BLOCK_SIZE / STRIDE; + const uint n_after_sum_in_group = BLOCK_WG_N / sum_per_n_token_in_block; + const uint n_after_sum_pad = n_after_sum_in_group * n_groups; + kq_exp_partial_sum += hq * n_after_sum_pad * m_pad * (uint)sizeof(half); + +#define CONCAT_IMPL(a, b) gemm_qk_ ##a ##x ##b ##_xe2 +#define CONCAT(x, y) CONCAT_IMPL(x, y) +#define FUNC CONCAT(BLOCK_SG_M, BLOCK_SG_N) + FUNC(id_wg_m, id_wg_n, hq, slm, key_cache, query, block_indices, block_indices_begins, kq_max_wg, kq_exp_partial_sum, M, N, K, query_stride, q_start_strided); +} + +} // NAMESPACE \ No newline at end of file From ac882abb7f96f9e3ca1431745a357a553bf777ac Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 5 Sep 2025 09:04:51 +0800 Subject: [PATCH 09/96] qwen2.5-1.5b 4k trunk works with xatten. --- .../src/graph/impls/cm/include/find_block.hpp | 1 + .../src/graph/impls/cm/paged_attention.cpp | 16 +++++ .../graph/impls/cm/paged_attention_gen.cpp | 59 +++++++++++++++---- .../graph/impls/cm/paged_attention_gen.hpp | 2 +- .../src/graph/impls/cm/xattn_gemm_qk.cm | 12 ++++ .../src/graph/include/paged_attention_inst.h | 4 ++ src/plugins/intel_gpu/src/runtime/memory.cpp | 6 +- 7 files changed, 83 insertions(+), 17 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp index b26e97b2478b1d..ae16f6b48f12b5 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp @@ -25,6 +25,7 @@ #include "sort.hpp" +#define MYMIN(x, y) ((x) < (y) ? (x) : (y)) // kq_max_wg: [b, hq, n_groups, q_stride_pad] // kq_exp_partial_sum: [b, hq, q_stride_pad, k_block_pad] diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 2452b8673ef78a..52ea5b929f4e26 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -30,6 +30,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { Stage::Ptr pa_single_token = make_stage(); Stage::Ptr pa_single_token_finalization = make_stage(); Stage::Ptr pa_multi_token = make_stage(); + Stage::Ptr xattn_estimate_gemmqk = make_stage(); + Stage::Ptr xattn_estimate_find_block = make_stage(); PagedAttentionCmImpl(): PrimitiveImplCM(PagedAttentionImplementationManager::get_type_info_static()) { m_rt_params = std::make_unique(); @@ -42,6 +44,10 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { add_stage(pa_single_token, params); add_stage(pa_single_token_finalization, params); add_stage(pa_multi_token, params); +#if PA_SPARSE_BLOCK_SIZE > 1 + add_stage(xattn_estimate_gemmqk, params); + add_stage(xattn_estimate_find_block, params); +#endif } void update_rt_params(const primitive_inst& instance) override { @@ -83,6 +89,16 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { res_event = {execute_stage(res_event, instance, kv_cache_update)}; if (rt_params->stage == PagedAttentionStage::PREFILL || rt_params->stage == PagedAttentionStage::MIXED) { +#if PA_SPARSE_BLOCK_SIZE > 1 + cldnn::stream& stream = instance.get_network().get_stream(); + stream.finish(); + res_event = {execute_stage(res_event, instance, xattn_estimate_gemmqk)}; + stream.finish(); + std::cout << "finish xattn_estimate_gemmqk!\n"; + res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; + stream.finish(); + std::cout << "finish xattn_estimate_find_block!\n"; +#endif res_event = {execute_stage(res_event, instance, pa_multi_token)}; } else if (rt_params->stage == PagedAttentionStage::GENERATE) { res_event = {execute_stage(res_event, instance, pa_single_token)}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 7ab603dc737fa3..2dc2726bfe1dc7 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -186,6 +186,26 @@ size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx) { return paged_attention_past_len; } +inline void dump_block_indices_begins(const kernel_impl_params& params) { + const auto& input_mem = params.memory_deps; + const auto mem = input_mem.at(PagedAttentionInputIdx::BLOCK_INDICES_BEGINS); + mem_lock mem_lock(mem, *params.strm); + std::cout << "============ dump BLOCK_INDICES_BEGINS ["; + for (size_t i = 0; i < mem->count(); i++) + std::cout << mem_lock[i] << ", "; + std::cout << "]" << std::endl; +} + +inline void dump_block_indices(const kernel_impl_params& params) { + const auto& input_mem = params.memory_deps; + const auto mem = input_mem.at(PagedAttentionInputIdx::BLOCK_INDICES); + mem_lock mem_lock(mem, *params.strm); + std::cout << "============ dump BLOCK_INDICES ["; + for (size_t i = 0; i < mem->count(); i++) + std::cout << mem_lock[i] << ", "; + std::cout << "]" << std::endl; +} + PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) { const auto& query_shape = impl_param.get_input_layout(PagedAttentionInputIdx::QUERY).get_partial_shape(); const auto& past_lens_shape = impl_param.get_input_layout(PagedAttentionInputIdx::PAST_LENS).get_partial_shape(); @@ -634,13 +654,18 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp jit.make("BLOCK_SG_N", BLOCK_SG_N); jit.make("BLOCK_SIZE", get_xattn_block_size()); jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); - jit.add(make_jit_constant("SCALE_FACTOR", scale_factor)); + jit.add(make_jit_constant("INV_S", scale_factor)); jit.make("BLOCK_SHARE_MAX", BLOCK_WG_N); jit.make("USE_KQ", 1); - jit.make("IS_CAUSAL", 1); + jit.make("IS_CAUSAL", 0); jit.make("USE_INT8", 0); jit.make("HEAD_SIZE_KEY", desc->k_head_size); + // for (auto& it : jit) { + // std::cout << "\tjit[" << it.name << "] = " << it.value << std::endl; + // } + // std::cout << std::endl; + return jit; } @@ -685,7 +710,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { auto querry_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; auto key_layout = params.input_layouts[PagedAttentionInputIdx::KEY]; - if (DEBUG_ENABLED) { // Debug + if (1 || DEBUG_ENABLED) { // Debug std::cout << "XAttentionEstimateGEMMQK::get_dispatch_data_func: " << "key_layout: " << key_layout.to_string() << ", querry_layout: " << querry_layout.to_string() << std::endl; std::cout << "\tkey_dims = ["; @@ -703,7 +728,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { std::cout << static_cast(it) << ", "; } std::cout << "]" << std::endl; - std::cout << "\ttquery_pads = ["; + std::cout << "\tquery_pads = ["; for (auto& it : querry_layout.get_padded_dims()) { std::cout << static_cast(it) << ", "; } @@ -727,7 +752,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { } return pitch; }; - const uint query_pitch = get_simple_pitch(querry_layout); + const uint query_pitch = get_simple_pitch(querry_layout) * STRIDE; const uint slice_no = 0, slice = 0; const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); @@ -738,26 +763,34 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { wgs.local = {SG_N, SG_M, 1}; const uint q_start_strided = N - M; - OPENVINO_ASSERT(N > M, "length of key cache must be greater or equal than query"); + OPENVINO_ASSERT(N >= M, "length of key cache must be greater or equal than query"); auto& scalars = kd.params.scalars; std::vector scaler_value = {M, N, K, query_pitch, slice_no, slice, q_start_strided}; scalars.resize(scaler_value.size()); - if (DEBUG_ENABLED) { // Debug + if (1 || DEBUG_ENABLED) { // Debug size_t kv_len = get_kv_len(params, PagedAttentionStage::PREFILL); size_t max_context_len = get_max_context_len(params); size_t past_len = get_past_len(params, 0); std::cout << "XAttentionEstimateGEMMQK::get_dispatch_data_func: " - << "N_kq_groups: " << N_kq_groups << ", q_stride_pad: " << q_stride_pad << ", q_start_strided: " << q_start_strided << ", kv_len: " << kv_len + << "N_kq_groups: " << N_kq_groups << ", q_stride_pad: " << q_stride_pad << ", scaler_value: " << PartialShape(scaler_value) << ", kv_len: " << kv_len << ", max_context_len = " << max_context_len << ", past_len = " << past_len << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + + dump_block_indices_begins(params); + dump_block_indices(params); } for (size_t i = 0; i < scaler_value.size(); ++i) { - scalars[i].t = ScalarDescriptor::Types::UINT32; - scalars[i].v.u32 = static_cast(scaler_value[i]); + if (i == 4 || i == 5) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast(scaler_value[i]); + } else { + scalars[i].t = ScalarDescriptor::Types::UINT32; + scalars[i].v.u32 = static_cast(scaler_value[i]); + } } }}; } @@ -803,7 +836,7 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { const size_t sum_per_n_token_in_block = block_size / STRIDE; - const size_t batch = params.input_layouts[PagedAttentionInputIdx::QUERY].get_partial_shape()[0].get_length(); + // const size_t batch = params.input_layouts[PagedAttentionInputIdx::QUERY].get_partial_shape()[0].get_length(); const size_t heads_num = desc->heads_num; // const size_t head_size = desc->k_head_size; @@ -824,14 +857,14 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { const uint q_block = ceil_div(q_stride, sum_per_n_token_in_block); const uint k_block = ceil_div(k_stride, sum_per_n_token_in_block); - wgs.global = {q_stride_pad / sum_per_n_token_in_block, heads_num, batch}; + wgs.global = {q_stride_pad / sum_per_n_token_in_block, heads_num, 1}; wgs.local = {1, 1, 1}; auto& scalars = kd.params.scalars; std::vector scaler_value = {q_stride, q_stride_pad, k_block_pad, k_block - q_block}; scalars.resize(scaler_value.size() + 1); - if (DEBUG_ENABLED) { // Debug + if (1 || DEBUG_ENABLED) { // Debug std::cout << "XAttentionEstimateFindBlock::get_dispatch_data_func: " << "k_block: " << k_block << ", q_block: " << q_block << "q_stride: " << q_stride << ", q_stride_pad: " << q_stride_pad << ", k_block_pad: " << k_block_pad << ", gws: [" << wgs.global[0] << ", " diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index cf9e359661b17f..03d8e6369ac673 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -35,7 +35,7 @@ constexpr auto get_pa_build_options() { // sparse attention block size is set to 1 to disable sparse attention support in CM kernels #define PA_SPARSE_BLOCK_SIZE 128 -constexpr uint BLOCK_SG_M = 64; //32 +constexpr uint BLOCK_SG_M = 64; constexpr uint BLOCK_SG_N = 32; constexpr uint SG_M = 4; constexpr uint SG_N = 8; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm index b6175b73b0f5f1..f1760bcd0e2f0a 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm @@ -77,11 +77,23 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( cm_slm_init(slm_size); auto slm = cm_slm_alloc(slm_size); + // printf("=============== gemm_qk ===================="); static_assert(HQ % HK == 0, "HQ must be multiple of HK"); uint id_wg_m, id_wg_n; get_mn(id_wg_m, id_wg_n, M, N, slice_no, slice, BLOCK_WG_M, BLOCK_WG_N); + // auto wg_id_N = cm_group_id(0); + // auto wg_lid_N = cm_local_id(0); + // auto wg_id_M = cm_group_id(1); + // auto wg_lid_M = cm_local_id(1); + // printf("=============== wgN:%d.%d wgM:%d.%d hq %d: id_wg_m %d, id_wg_n %d, %p, %p, %p, %p; test: %p,%p, M %d, N %d, K %d, %d, %d; %d, %d\n", + // wg_id_N, wg_lid_N, wg_id_M, wg_lid_M, hq, + // id_wg_m, id_wg_n, query, key_cache, kq_max_wg, kq_exp_partial_sum, + // key_cache+sizeof(half), query+sizeof(half), + // M, N, K, query_stride, q_start_strided, + // KV_BLOCK_SIZE, HEAD_SIZE_KEY); + // key cache: [block, HQ, KV_BLOCK_SIZE, HEAD_SIZE_KEY] #if USE_INT8 key_cache += hk * (KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h index b635c45d08a5f5..a832ee6bbf48d4 100644 --- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h @@ -24,6 +24,10 @@ struct typed_program_node : public typed_program_node_basehas_score_aggregation) input_ports.insert(PagedAttentionInputIdx::SCORE_AGGREGATION); diff --git a/src/plugins/intel_gpu/src/runtime/memory.cpp b/src/plugins/intel_gpu/src/runtime/memory.cpp index f69a3124da7d6d..c43526fd6be184 100644 --- a/src/plugins/intel_gpu/src/runtime/memory.cpp +++ b/src/plugins/intel_gpu/src/runtime/memory.cpp @@ -35,9 +35,9 @@ MemoryTracker::~MemoryTracker() { try { m_engine->subtract_memory_used(m_buffer_size, m_alloc_type); } catch (...) {} - GPU_DEBUG_TRACE_DETAIL << "Free " << m_buffer_size << " bytes of " << m_alloc_type << " allocation type ptr = " << m_buffer_ptr - << " (current=" << m_engine->get_used_device_memory(m_alloc_type) << ";" - << " max=" << m_engine->get_max_used_device_memory(m_alloc_type) << ")" << std::endl; + // GPU_DEBUG_TRACE_DETAIL << "Free " << m_buffer_size << " bytes of " << m_alloc_type << " allocation type ptr = " << m_buffer_ptr + // << " (current=" << m_engine->get_used_device_memory(m_alloc_type) << ";" + // << " max=" << m_engine->get_max_used_device_memory(m_alloc_type) << ")" << std::endl; } } From 0621e4b30e979c1aac170e20bb3bb20521d7301a Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 5 Sep 2025 11:30:55 +0800 Subject: [PATCH 10/96] 4k aligned works. --- .../src/graph/impls/cm/include/estimate.hpp | 117 +++++++++++------- .../src/graph/impls/cm/include/find_block.hpp | 67 ++++++---- .../src/graph/impls/cm/paged_attention.cpp | 8 +- .../graph/impls/cm/paged_attention_gen.cpp | 17 ++- .../src/graph/impls/cm/xattn_find_block.cm | 10 +- .../src/graph/impls/cm/xattn_gemm_qk.cm | 32 +++-- 6 files changed, 147 insertions(+), 104 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp index 879c9bda50359c..ad8d2696356cb1 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp @@ -729,7 +729,7 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { uint M_aligned = (M + BLOCK_WG_M - 1) / BLOCK_WG_M * BLOCK_WG_M; uint K_block_pad = N_block * (BLOCK_WG_N / (BLOCK_SIZE / STRIDE)); const uint block_size_div_stride = BLOCK_SIZE / STRIDE; - constexpr half log2e = 1.4426950408889634f; + constexpr SOFTMAX_TYPE log2e = 1.4426950408889634f; //static_assert(BLOCK_SG_M / block_size_div_stride == 8, "BLOCK_SG_M / block_size_div_stride should be 8"); #if IS_CAUSAL == 1 @@ -737,24 +737,24 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { // fill -inf -> max in group, 0 -> exp_sum to make compensation work { // current max -> mem - vector max_m = -60000; + vector max_m = -60000; // kq_max_wg: [b, hq, N/BLOCK_WG_N, M_aligned] - uint offset = (id_wg_n * M_aligned + id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * sizeof(half); + uint offset = (id_wg_n * M_aligned + id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * sizeof(SOFTMAX_TYPE); cm_ptr_store((int*)kq_max_wg, offset, max_m.format()); } { // store - matrix sum_t = 0; - lsc::block_2d_desc desc_c{ kq_exp_partial_sum, M - 1, (uint)(K_block_pad * sizeof(half) - 1), (uint)(K_block_pad * sizeof(half) - 1), + matrix sum_t = 0; + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, M - 1, (uint)(K_block_pad * sizeof(SOFTMAX_TYPE) - 1), (uint)(K_block_pad * sizeof(SOFTMAX_TYPE) - 1), (int)((id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) / block_size_div_stride), (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) }; - cm_store(desc_c, sum_t.format()); - cm_store(desc_c, sum_t.format()); - cm_store(desc_c, sum_t.format()); - cm_store(desc_c, sum_t.format()); - cm_store(desc_c, sum_t.format()); - cm_store(desc_c, sum_t.format()); - cm_store(desc_c, sum_t.format()); - cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); } return; @@ -776,17 +776,21 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { // M[0:16]xK[0:32] uint block_idx = (uint)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * STRIDE / KV_BLOCK_SIZE; + uint max_block_idx = (uint)(N * STRIDE + KV_BLOCK_SIZE - 1) / KV_BLOCK_SIZE - 1; + block_idx = MYMIN(block_idx, max_block_idx); #if USE_INT8 uint offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); lsc::block_2d_desc desc_b0{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(char) - 1), (uint)(K * sizeof(char) - 1), 0, 0 }; uint scale_offset0 = offset + KV_BLOCK_SIZE * HEAD_SIZE; - offset = block_indices_p[block_idx + 1] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); + block_idx = MYMIN(block_idx + 1, max_block_idx); + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); lsc::block_2d_desc desc_b1{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(char) - 1), (uint)(K * sizeof(char) - 1), 0, 0 }; uint scale_offset1 = offset + KV_BLOCK_SIZE * HEAD_SIZE; // prefetch B block_idx = (uint)(id_wg_n * BLOCK_WG_N + id_sg_mn * (BLOCK_WG_N / SG_MN)) * STRIDE / KV_BLOCK_SIZE; + block_idx = MYMIN(block_idx, max_block_idx); offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); static_assert(BLOCK_WG_N / SG_MN <= KEY_LINES_PER_LOAD, "prefetch lines should be inside one block"); lsc::block_2d_desc desc_prefetch_b{ key_cache + offset, BLOCK_WG_N / SG_MN - 1, (uint)(K * sizeof(char) - 1), (uint)(K * sizeof(char) - 1), @@ -801,15 +805,20 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { uint offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); lsc::block_2d_desc desc_b0{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), 0, 0 }; - offset = block_indices_p[block_idx + 1] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + // printf("===============0 lid:%d.%d, block_idx=%d, offset=%u\n", id_sg_n, id_sg_m, block_idx, offset); + block_idx = MYMIN(block_idx + 1, max_block_idx); + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); lsc::block_2d_desc desc_b1{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), 0, 0 }; + // printf("===============1 lid:%d.%d, block_idx=%d, offset=%u\n", id_sg_n, id_sg_m, block_idx, offset); // prefetch B block_idx = (uint)(id_wg_n * BLOCK_WG_N + id_sg_mn * (BLOCK_WG_N / SG_MN)) * STRIDE / KV_BLOCK_SIZE; + block_idx = MYMIN(block_idx, max_block_idx); offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); static_assert(BLOCK_WG_N / SG_MN <= KEY_LINES_PER_LOAD, "prefetch lines should be inside one block"); lsc::block_2d_desc desc_prefetch_b{ key_cache + offset, BLOCK_WG_N / SG_MN - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), 0, 0 }; + // printf("===============2 lid:%d.%d, block_idx=%d, offset=%u\n", id_sg_n, id_sg_m, block_idx, offset); // 0~2 M[:]xK[0:16] 2~4 K[16:32] --> 32 * 2 regs matrix b0, b1; #endif @@ -823,10 +832,24 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { // load b: N[0:16]xK[0:16] #if USE_INT8 - scales_block[0].format() = cm_ptr_load((uint64_t*)key_cache, scale_offset0); - zps_block[0].format() = cm_ptr_load((uint64_t*)key_cache, scale_offset0 + KV_BLOCK_SIZE * (uint)sizeof(half)); - scales_block[1].format() = cm_ptr_load((uint64_t*)key_cache, scale_offset1); - zps_block[1].format() = cm_ptr_load((uint64_t*)key_cache, scale_offset1 + KV_BLOCK_SIZE * (uint)sizeof(half)); + { + lsc::block_2d_desc desc_scale{ key_cache + scale_offset0, 16 * 2 - 1, (uint)(16 * sizeof(half) - 1), (uint)(16 * sizeof(half) - 1), + 0, 0 }; + matrix tmp_scale, tmp_zp; + cm_load(tmp_scale.format(), desc_scale); + cm_load(tmp_zp.format(), desc_scale); + scales_block[0].format().select<8, 2, 16, 1>(0) = tmp_scale.format().select<8, 1, 16, 2>(0, 0); + scales_block[0].format().select<8, 2, 16, 1>(1) = tmp_scale.format().select<8, 1, 16, 2>(0, 1); + zps_block[0].format().select<8, 2, 16, 1>(0) = tmp_zp.format().select<8, 1, 16, 2>(0, 0); + zps_block[0].format().select<8, 2, 16, 1>(1) = tmp_zp.format().select<8, 1, 16, 2>(0, 1); + desc_scale.set_base(key_cache + scale_offset1); + cm_load(tmp_scale.format(), desc_scale); + cm_load(tmp_zp.format(), desc_scale); + scales_block[1].format().select<8, 2, 16, 1>(0) = tmp_scale.format().select<8, 1, 16, 2>(0, 0); + scales_block[1].format().select<8, 2, 16, 1>(1) = tmp_scale.format().select<8, 1, 16, 2>(0, 1); + zps_block[1].format().select<8, 2, 16, 1>(0) = tmp_zp.format().select<8, 1, 16, 2>(0, 0); + zps_block[1].format().select<8, 2, 16, 1>(1) = tmp_zp.format().select<8, 1, 16, 2>(0, 1); + } cm_load(b0_up_s8.format(), desc_b0); cm_load(b0_down_s8.format(), desc_b1); @@ -873,16 +896,16 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { for (uint s = 0; s < STRIDE; s++) { #if USE_INT8 - auto tmp = scales_block[0].select<16, 16>(s); + auto tmp = scales_block[0].select<16, 1>(s * 16); scales[0].select<16, 2>(0) = tmp; scales[0].select<16, 2>(1) = scales[0].select<16, 2>(0); - tmp = scales_block[1].select<16, 16>(s); + tmp = scales_block[1].select<16, 1>(s * 16); scales[1].select<16, 2>(0) = tmp; scales[1].select<16, 2>(1) = scales[1].select<16, 2>(0); - tmp = zps_block[0].select<16, 16>(s); + tmp = zps_block[0].select<16, 1>(s * 16); zps[0].select<16, 2>(0) = tmp; zps[0].select<16, 2>(1) = zps[0].select<16, 2>(0); - tmp = zps_block[1].select<16, 16>(s); + tmp = zps_block[1].select<16, 1>(s * 16); zps[1].select<16, 2>(0) = tmp; zps[1].select<16, 2>(1) = zps[1].select<16, 2>(0); #endif @@ -985,7 +1008,7 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { cm_sbarrier(0); - matrix acc_half; + matrix acc_half; #pragma unroll for (uint reg_m = 0; reg_m < REG_M; reg_m++) { #pragma unroll @@ -1003,7 +1026,7 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { n_start = MYMIN(n_start, N); int n_end = MYMIN(n_start + BLOCK_SG_N, N); int valid_n = n_end - n_start; - matrix sum_t; + matrix sum_t; vector seq_m; cmtl::cm_vector_assign(seq_m.select_all(), 0, 1); vector_ref seq = seq_m.select(); @@ -1019,37 +1042,37 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { #pragma unroll for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { SIMD_IF_BEGIN (n_pos > m_start + reg_m) { - acc_half.row(reg_m) = half{-60000}; + acc_half.row(reg_m) = SOFTMAX_TYPE{-60000}; } SIMD_IF_END; } } #else bool skip_mask = true; #endif - vector max_m; + vector max_m; if (valid_n != BLOCK_SG_N) { #pragma unroll for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { - acc_half.row(reg_m).merge(half{-60000}, n_pos >= N); + acc_half.row(reg_m).merge(SOFTMAX_TYPE{-60000}, n_pos >= N); } } - max_m.select<32, 1>() = reduce2d<1, 0, 1>(acc_half.select<32, 1, 32, 1>()).format(); - max_m.select<32, 1>(32) = reduce2d<1, 0, 1>(acc_half.select<32, 1, 32, 1>(32)).format(); + max_m.select<32, 1>() = reduce2d<1, 0, 1>(acc_half.select<32, 1, 32, 1>()).format(); + max_m.select<32, 1>(32) = reduce2d<1, 0, 1>(acc_half.select<32, 1, 32, 1>(32)).format(); { - uint slm_offset = (id_sg_n * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * (uint)sizeof(half); + uint slm_offset = (id_sg_n * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * (uint)sizeof(SOFTMAX_TYPE); // current max -> slm cm_slm_block_write(slm, slm_offset, max_m.format()); cm_slm_fence(CM_LOCAL_BARRIER); cm_barrier(); // max inside wg - cm_slm_block_read(slm, id_sg_m * BLOCK_SG_M * (uint)sizeof(half), max_m.format()); - vector tmp; + cm_slm_block_read(slm, id_sg_m * BLOCK_SG_M * (uint)sizeof(SOFTMAX_TYPE), max_m.format()); + vector tmp; #pragma unroll for (uint i = 1; i < SG_N; i++) { - slm_offset = (i * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * (uint)sizeof(half); + slm_offset = (i * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * (uint)sizeof(SOFTMAX_TYPE); cm_slm_block_read(slm, slm_offset, tmp.format()); - max_m = cm_max(max_m, tmp); + max_m = cm_max(max_m, tmp); } // max across wg // kq_max: [b, hq, M_aligned] @@ -1058,7 +1081,7 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { // current max -> mem // kq_max_wg: [b, hq, N/BLOCK_WG_N, M_aligned] - uint offset = (id_wg_n * M_aligned + id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * sizeof(half); + uint offset = (id_wg_n * M_aligned + id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * sizeof(SOFTMAX_TYPE); cm_ptr_store((int*)kq_max_wg, offset, max_m.format()); } { @@ -1089,19 +1112,19 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { } SIMD_IF_END; } } - sum_t.select<32, 1, 4, 1>( 0).format() = reduce2d<4, 1, 4>(acc_half.select<32, 1, 32, 1>( 0)).format(); - sum_t.select<32, 1, 4, 1>(32).format() = reduce2d<4, 1, 4>(acc_half.select<32, 1, 32, 1>(32)).format(); + sum_t.select<32, 1, 4, 1>( 0).format() = reduce2d<4, 1, 4>(acc_half.select<32, 1, 32, 1>( 0)).format(); + sum_t.select<32, 1, 4, 1>(32).format() = reduce2d<4, 1, 4>(acc_half.select<32, 1, 32, 1>(32)).format(); } // store - lsc::block_2d_desc desc_c{ kq_exp_partial_sum, M - 1, (uint)(K_block_pad * sizeof(half) - 1), (uint)(K_block_pad * sizeof(half) - 1), + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, M - 1, (uint)(K_block_pad * sizeof(SOFTMAX_TYPE) - 1), (uint)(K_block_pad * sizeof(SOFTMAX_TYPE) - 1), (int)((id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) / block_size_div_stride), (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) }; - cm_store(desc_c, sum_t.select<8, 1, 4, 1>( 0).format()); - cm_store(desc_c, sum_t.select<8, 1, 4, 1>( 8).format()); - cm_store(desc_c, sum_t.select<8, 1, 4, 1>(16).format()); - cm_store(desc_c, sum_t.select<8, 1, 4, 1>(24).format()); - cm_store(desc_c, sum_t.select<8, 1, 4, 1>(32).format()); - cm_store(desc_c, sum_t.select<8, 1, 4, 1>(40).format()); - cm_store(desc_c, sum_t.select<8, 1, 4, 1>(48).format()); - cm_store(desc_c, sum_t.select<8, 1, 4, 1>(56).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>( 0).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>( 8).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(16).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(24).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(32).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(40).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(48).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(56).format()); } #endif diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp index ae16f6b48f12b5..5af1329a44fd6d 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp @@ -27,6 +27,12 @@ #define MYMIN(x, y) ((x) < (y) ? (x) : (y)) +#define MYCONCAT(x, y) x ## y +#define IS_float 1 +#define IS_half 2 +#define CUR_TYPE_(a) MYCONCAT(IS_, a) +#define CUR_TYPE CUR_TYPE_(SOFTMAX_TYPE) + // kq_max_wg: [b, hq, n_groups, q_stride_pad] // kq_exp_partial_sum: [b, hq, q_stride_pad, k_block_pad] // kq_sum: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] @@ -50,43 +56,53 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p const int TOKEN_IN_BLOCK = (BLOCK_SIZE / STRIDE); int m = m_block * TOKEN_IN_BLOCK; - vector max_m; + vector max_m; const int TOKEN_SHARE_MAX = BLOCK_SHARE_MAX / TOKEN_IN_BLOCK; - kq_exp_partial_sum += m * k_block_pad * (int)sizeof(half); - kq_max_wg += m * (int)sizeof(half); - constexpr half log2e = 1.4426950408889634f; + kq_exp_partial_sum += m * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + kq_max_wg += m * (int)sizeof(SOFTMAX_TYPE); + constexpr SOFTMAX_TYPE log2e = 1.4426950408889634f; matrix sum_m = 0; - matrix data; + matrix data; int m_start = MYMIN(m, q_stride); int m_end = MYMIN(m_start + TOKEN_SHARE_MAX, q_stride); int valid_m = m_end - m_start; if (valid_m == 0) return; - lsc::block_2d_desc desc_sum{ kq_exp_partial_sum, (uint)valid_m - 1, (uint)(k_block_pad * sizeof(half) - 1), (uint)(k_block_pad * sizeof(half) - 1), + lsc::block_2d_desc desc_sum{ kq_exp_partial_sum, (uint)valid_m - 1, (uint)(k_block_pad * sizeof(SOFTMAX_TYPE) - 1), (uint)(k_block_pad * sizeof(SOFTMAX_TYPE) - 1), 0, 0 }; { // find max: (k_block_pad / TOKEN_SHARE_MAX) * q_stride_pad - max_m = half{-60000}; + max_m = SOFTMAX_TYPE{-60000}; for (int idx = 0; idx < k_block_pad / TOKEN_SHARE_MAX; idx++) { - vector max_m_in_group; - max_m_in_group.format() = cm_ptr_load((int*)kq_max_wg, q_stride_pad * idx * (int)sizeof(half)); - max_m = cm_max(max_m, max_m_in_group); + vector max_m_in_group; + max_m_in_group.format() = cm_ptr_load((int*)kq_max_wg, q_stride_pad * idx * (int)sizeof(SOFTMAX_TYPE)); + max_m = cm_max(max_m, max_m_in_group); } } // compensation: val*exp(local - global) desc_sum.set_block_x(0); for (int j = 0, idx = 0; j < k_block_pad; j += TOKEN_SHARE_MAX, idx++) { - vector max_m_in_group; - max_m_in_group.format() = cm_ptr_load((int*)kq_max_wg, q_stride_pad * idx * (int)sizeof(half)); - cm_load(data.format(), desc_sum); + vector max_m_in_group; + max_m_in_group.format() = cm_ptr_load((int*)kq_max_wg, q_stride_pad * idx * (int)sizeof(SOFTMAX_TYPE)); +#if CUR_TYPE == IS_float + cm_load(data.select(0, 0).format(), desc_sum); + cm_load(data.select(0, TOKEN_SHARE_MAX / 2).format(), desc_sum); +#else + cm_load(data.format(), desc_sum); +#endif for (int i = 0; i < TOKEN_IN_BLOCK; i++) { if (i < valid_m) { data.row(i) *= cm_exp((max_m_in_group[i] - max_m[i]) * log2e); sum_m.row(i) += data.row(i); } } - cm_store(desc_sum, data.format()); +#if CUR_TYPE == IS_float + cm_store(desc_sum, data.select(0, 0).format()); + cm_store(desc_sum, data.select(0, TOKEN_SHARE_MAX / 2).format()); +#else + cm_store(desc_sum, data.format()); +#endif desc_sum.set_block_x(desc_sum.get_block_x() + TOKEN_SHARE_MAX); } @@ -105,16 +121,23 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p kq_sum += m_block * k_block_pad * (int)sizeof(half); #endif for (int j = 0; j < k_block_pad; j += TOKEN_SHARE_MAX) { - cm_load(data.format(), desc_sum); +#if CUR_TYPE == IS_float + cm_load(data.select(0, 0).format(), desc_sum); + cm_load(data.select(0, TOKEN_SHARE_MAX / 2).format(), desc_sum); +#else + cm_load(data.format(), desc_sum); +#endif data.row(0) *= inv_sum_v[0]; for (int i = 1; i < TOKEN_IN_BLOCK; i++) { data.row(0) += data.row(i) * inv_sum_v[i]; } desc_sum.set_block_x(desc_sum.get_block_x() + TOKEN_SHARE_MAX); sum_m_after_add += data.row(0); - cm_ptr_store((int*)kq_exp_partial_sum, j * (int)sizeof(half), data.row(0).format()); + // the sum type is always half in the reference code of the paper: https://github.com/mit-han-lab/x-attention/blob/fb2ac200a23d20568f7d166ddb5ee247926d2b2b/xattn/src/kernels.py#L248 + vector data_half = data.row(0); + cm_ptr_store((int*)kq_exp_partial_sum, j * (int)sizeof(half), data_half.format()); #if DEBUG_ACC == 1 - cm_ptr_store((int*)kq_sum, j * (int)sizeof(half), data.row(0).format()); + cm_ptr_store((int*)kq_sum, j * (int)sizeof(half), data_half.format()); #endif } auto thresh_act = cm_sum(sum_m_after_add) * thresh; @@ -126,11 +149,11 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p // line 5: sorted tmp // line 6: accumalative score block_mask += m_block * k_block_pad; - auto score = kq_exp_partial_sum + 0 * k_block_pad * (int)sizeof(half); - auto sorted_value = kq_exp_partial_sum + 1 * k_block_pad * (int)sizeof(half); - auto sorted_index = kq_exp_partial_sum + 3 * k_block_pad * (int)sizeof(half); - auto sorted_tmp = kq_exp_partial_sum + 5 * k_block_pad * (int)sizeof(half); - auto acc_score = kq_exp_partial_sum + 6 * k_block_pad * (int)sizeof(half); + auto score = kq_exp_partial_sum + 0 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + auto sorted_value = kq_exp_partial_sum + 1 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + auto sorted_index = kq_exp_partial_sum + 3 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + auto sorted_tmp = kq_exp_partial_sum + 5 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + auto acc_score = kq_exp_partial_sum + 6 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); #if IS_CAUSAL == 1 auto score_p = (half*)score; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 52ea5b929f4e26..ce9d5b601b5d08 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -94,10 +94,10 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { stream.finish(); res_event = {execute_stage(res_event, instance, xattn_estimate_gemmqk)}; stream.finish(); - std::cout << "finish xattn_estimate_gemmqk!\n"; + // std::cout << "finish xattn_estimate_gemmqk!\n"; res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; stream.finish(); - std::cout << "finish xattn_estimate_find_block!\n"; + // std::cout << "finish xattn_estimate_find_block!\n"; #endif res_event = {execute_stage(res_event, instance, pa_multi_token)}; } else if (rt_params->stage == PagedAttentionStage::GENERATE) { @@ -162,7 +162,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); auto count_kq_max_wg = static_cast(desc->heads_num * N_kq_groups * q_stride_pad); - internal_buffers.emplace_back(count_kq_max_wg, ov::element::f16); // 2: kq_max_wg + internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg const size_t block_size = get_xattn_block_size(); if (block_size > 1) { @@ -172,7 +172,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { const uint k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; const uint k_block_pad = k_block_in_group * N_kq_groups; auto count_kq_exp_partial_sum = static_cast(desc->heads_num * q_stride_pad * k_block_pad); - internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f16); // 3: kq_exp_partial_sum + internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f32); // 3: kq_exp_partial_sum auto count_elements_mask = static_cast(desc->heads_num * (q_stride_pad / sum_per_n_token_in_block) * k_block_pad); internal_buffers.emplace_back(count_elements_mask, ov::element::boolean); // 4: sparse_block_mask diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 2dc2726bfe1dc7..3b2bf51c52c75f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -656,10 +656,12 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); jit.add(make_jit_constant("INV_S", scale_factor)); jit.make("BLOCK_SHARE_MAX", BLOCK_WG_N); - jit.make("USE_KQ", 1); - jit.make("IS_CAUSAL", 0); + //# loop order walks HQ first and the step is WALK_HQ, 1 means not walk HQ, 2 means walks 2 heads first. Valid value: 1, 2, 4... + jit.make("WALK_HQ", desc->heads_num != desc->kv_heads_num ? 2 : 1); + jit.make("IS_CAUSAL", 1); jit.make("USE_INT8", 0); jit.make("HEAD_SIZE_KEY", desc->k_head_size); + jit.make("SOFTMAX_TYPE", "float"); // for (auto& it : jit) { // std::cout << "\tjit[" << it.name << "] = " << it.value << std::endl; @@ -710,7 +712,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { auto querry_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; auto key_layout = params.input_layouts[PagedAttentionInputIdx::KEY]; - if (1 || DEBUG_ENABLED) { // Debug + if (DEBUG_ENABLED) { // Debug std::cout << "XAttentionEstimateGEMMQK::get_dispatch_data_func: " << "key_layout: " << key_layout.to_string() << ", querry_layout: " << querry_layout.to_string() << std::endl; std::cout << "\tkey_dims = ["; @@ -758,8 +760,11 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + //# loop order walks HQ first and the step is WALK_HQ, 1 means not walk HQ, 2 means walks 2 heads first. Valid value: 1, 2, 4... + const size_t WALK_HQ = desc->heads_num != desc->kv_heads_num ? 2 : 1; + auto& wgs = kd.params.workGroups; - wgs.global = {N_kq_groups * (q_stride_pad / BLOCK_WG_M) * SG_N, SG_M, heads_num}; + wgs.global = {N_kq_groups * (q_stride_pad / BLOCK_WG_M) * SG_N * WALK_HQ, SG_M, heads_num / WALK_HQ}; wgs.local = {SG_N, SG_M, 1}; const uint q_start_strided = N - M; @@ -769,7 +774,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { std::vector scaler_value = {M, N, K, query_pitch, slice_no, slice, q_start_strided}; scalars.resize(scaler_value.size()); - if (1 || DEBUG_ENABLED) { // Debug + if (DEBUG_ENABLED) { // Debug size_t kv_len = get_kv_len(params, PagedAttentionStage::PREFILL); size_t max_context_len = get_max_context_len(params); size_t past_len = get_past_len(params, 0); @@ -864,7 +869,7 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { std::vector scaler_value = {q_stride, q_stride_pad, k_block_pad, k_block - q_block}; scalars.resize(scaler_value.size() + 1); - if (1 || DEBUG_ENABLED) { // Debug + if (DEBUG_ENABLED) { // Debug std::cout << "XAttentionEstimateFindBlock::get_dispatch_data_func: " << "k_block: " << k_block << ", q_block: " << q_block << "q_stride: " << q_stride << ", q_stride_pad: " << q_stride_pad << ", k_block_pad: " << k_block_pad << ", gws: [" << wgs.global[0] << ", " diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm index 7cecacffaf0c27..2201d2b088eebb 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm @@ -24,11 +24,7 @@ namespace KERNEL_NAME { // _GENX_MAIN_ void find_block( extern "C" _GENX_MAIN_ void KERNEL_NAME( - svmptr_t kq_max_wg ATTR, - svmptr_t kq_exp_partial_sum ATTR, - svmptr_t block_mask ATTR, - uint q_stride, uint q_stride_pad, uint k_block_pad, - uint causal_start_index, float thresh + svmptr_t kq_max_wg ATTR, svmptr_t kq_exp_partial_sum ATTR, svmptr_t block_mask ATTR, uint q_stride, uint q_stride_pad, uint k_block_pad, uint causal_start_index, float thresh #if DEBUG_ACC == 1 , svmptr_t kq_sum ATTR #endif @@ -44,8 +40,8 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint m = cm_group_id(0); uint hq = cm_group_id(1); uint b = cm_group_id(2); - kq_max_wg += (b * HQ + hq) * (k_block_pad / TOKEN_SHARE_MAX) * q_stride_pad * (uint)sizeof(half); - kq_exp_partial_sum += (b * HQ + hq) * q_stride_pad * k_block_pad * (uint)sizeof(half); + kq_max_wg += (b * HQ + hq) * (k_block_pad / TOKEN_SHARE_MAX) * q_stride_pad * (uint)sizeof(SOFTMAX_TYPE); + kq_exp_partial_sum += (b * HQ + hq) * q_stride_pad * k_block_pad * (uint)sizeof(SOFTMAX_TYPE); #if DEBUG_ACC == 1 kq_sum += (b * HQ + hq) * q_stride_pad / TOKEN_IN_BLOCK * k_block_pad * (uint)sizeof(half); #endif diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm index f1760bcd0e2f0a..4e1e305944983f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm @@ -20,7 +20,7 @@ namespace KERNEL_NAME { #define ABS(x) (x) < 0 ? -(x) : (x) CM_INLINE void get_mn(uint& id_wg_m, uint& id_wg_n, uint M, uint N, int slice_no, int slice, const int BLOCK_WG_M, const int BLOCK_WG_N) { - uint id_wg_mn = cm_group_id(0); + uint id_wg_mn = cm_group_id(0) / WALK_HQ; if (slice_no == 0) { if (slice == 0) { // loop M first, N is shared, total = N/256*M+N @@ -59,40 +59,36 @@ CM_INLINE void get_mn(uint& id_wg_m, uint& id_wg_n, uint M, uint N, int slice_no } } -// _GENX_MAIN_ void gemm_qk +// _GENX_MAIN_ void gemm_qk( extern "C" _GENX_MAIN_ void KERNEL_NAME( svmptr_t key_cache ATTR, svmptr_t query ATTR, svmptr_t block_indices ATTR, svmptr_t block_indices_begins ATTR, svmptr_t kq_max_wg ATTR, - svmptr_t kq_exp_partial_sum ATTR, + svmptr_t kq_exp_partial_sum ATTR, uint M, uint N, uint K, uint query_stride, int slice_no, int slice, uint q_start_strided) { const uint BLOCK_WG_M = BLOCK_SG_M * SG_M; const uint BLOCK_WG_N = BLOCK_SG_N * SG_N; const uint size_slm_b = 0; - uint hq = cm_group_id(2); + uint hq = cm_group_id(2) * WALK_HQ; + hq += cm_group_id(0) & (WALK_HQ - 1); + if (hq >= HQ) return; uint hk = hq / (HQ / HK); - const uint slm_size = SG_N * BLOCK_WG_M * sizeof(half); + const uint slm_size = SG_N * BLOCK_WG_M * sizeof(SOFTMAX_TYPE); cm_slm_init(slm_size); auto slm = cm_slm_alloc(slm_size); - // printf("=============== gemm_qk ===================="); static_assert(HQ % HK == 0, "HQ must be multiple of HK"); uint id_wg_m, id_wg_n; get_mn(id_wg_m, id_wg_n, M, N, slice_no, slice, BLOCK_WG_M, BLOCK_WG_N); - // auto wg_id_N = cm_group_id(0); - // auto wg_lid_N = cm_local_id(0); - // auto wg_id_M = cm_group_id(1); - // auto wg_lid_M = cm_local_id(1); - // printf("=============== wgN:%d.%d wgM:%d.%d hq %d: id_wg_m %d, id_wg_n %d, %p, %p, %p, %p; test: %p,%p, M %d, N %d, K %d, %d, %d; %d, %d\n", - // wg_id_N, wg_lid_N, wg_id_M, wg_lid_M, hq, - // id_wg_m, id_wg_n, query, key_cache, kq_max_wg, kq_exp_partial_sum, - // key_cache+sizeof(half), query+sizeof(half), - // M, N, K, query_stride, q_start_strided, - // KV_BLOCK_SIZE, HEAD_SIZE_KEY); + auto wg_id_N = cm_group_id(0); + auto wg_lid_N = cm_local_id(0); + auto wg_id_M = cm_group_id(1); + auto wg_lid_M = cm_local_id(1); + // printf("=============================================== wgN:%d.%d wgM:%d.%d hq %d =============================================== \n", wg_id_N, wg_lid_N, wg_id_M, wg_lid_M, hq); // key cache: [block, HQ, KV_BLOCK_SIZE, HEAD_SIZE_KEY] #if USE_INT8 @@ -108,12 +104,12 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( // kq_exp_partial_sum: [hq, m_pad, n_groups*BLOCK_WG_M/(BLOCK_SIZE/STRIDE)] uint m_pad = (M + BLOCK_WG_M - 1) / BLOCK_WG_M * BLOCK_WG_M; uint n_groups = (N + BLOCK_WG_N - 1) / BLOCK_WG_N; - kq_max_wg += hq * n_groups * m_pad * (uint)sizeof(half); + kq_max_wg += hq * n_groups * m_pad * (uint)sizeof(SOFTMAX_TYPE); const uint sum_per_n_token_in_block = BLOCK_SIZE / STRIDE; const uint n_after_sum_in_group = BLOCK_WG_N / sum_per_n_token_in_block; const uint n_after_sum_pad = n_after_sum_in_group * n_groups; - kq_exp_partial_sum += hq * n_after_sum_pad * m_pad * (uint)sizeof(half); + kq_exp_partial_sum += hq * n_after_sum_pad * m_pad * (uint)sizeof(SOFTMAX_TYPE); #define CONCAT_IMPL(a, b) gemm_qk_ ##a ##x ##b ##_xe2 #define CONCAT(x, y) CONCAT_IMPL(x, y) From 98a4ecd1499c077a37832de8af34fd3147128dd0 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 5 Sep 2025 15:10:34 +0800 Subject: [PATCH 11/96] fix block_mask not fully initialized issue. --- .../intel_gpu/src/graph/impls/cm/include/find_block.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp index 5af1329a44fd6d..519972b1c473d3 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp @@ -24,7 +24,6 @@ #include #include "sort.hpp" - #define MYMIN(x, y) ((x) < (y) ? (x) : (y)) #define MYCONCAT(x, y) x ## y @@ -68,6 +67,7 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p int m_end = MYMIN(m_start + TOKEN_SHARE_MAX, q_stride); int valid_m = m_end - m_start; if (valid_m == 0) return; + block_mask += m_block * k_block_pad; lsc::block_2d_desc desc_sum{ kq_exp_partial_sum, (uint)valid_m - 1, (uint)(k_block_pad * sizeof(SOFTMAX_TYPE) - 1), (uint)(k_block_pad * sizeof(SOFTMAX_TYPE) - 1), 0, 0 }; { @@ -120,6 +120,7 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p #if DEBUG_ACC == 1 kq_sum += m_block * k_block_pad * (int)sizeof(half); #endif + vector zero = 0; for (int j = 0; j < k_block_pad; j += TOKEN_SHARE_MAX) { #if CUR_TYPE == IS_float cm_load(data.select(0, 0).format(), desc_sum); @@ -139,6 +140,7 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p #if DEBUG_ACC == 1 cm_ptr_store((int*)kq_sum, j * (int)sizeof(half), data_half.format()); #endif + cm_ptr_store((int*)block_mask, j, zero.format()); } auto thresh_act = cm_sum(sum_m_after_add) * thresh; @@ -148,7 +150,6 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p // line 3: sorted index // line 5: sorted tmp // line 6: accumalative score - block_mask += m_block * k_block_pad; auto score = kq_exp_partial_sum + 0 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); auto sorted_value = kq_exp_partial_sum + 1 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); auto sorted_index = kq_exp_partial_sum + 3 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); From 5af3330fccf46fe4ea183f6d6ededed73a9e0fa6 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 8 Sep 2025 14:11:53 +0800 Subject: [PATCH 12/96] fix of find_block --- src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp | 2 +- src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp | 2 +- .../intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp index ad8d2696356cb1..e3cd5acbe78fb9 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp @@ -737,7 +737,7 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { // fill -inf -> max in group, 0 -> exp_sum to make compensation work { // current max -> mem - vector max_m = -60000; + vector max_m = -60000; // kq_max_wg: [b, hq, N/BLOCK_WG_N, M_aligned] uint offset = (id_wg_n * M_aligned + id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * sizeof(SOFTMAX_TYPE); cm_ptr_store((int*)kq_max_wg, offset, max_m.format()); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp index 519972b1c473d3..f46f2f2cce2795 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp @@ -178,7 +178,7 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p acc_score_p[0] = 0; acc_score_p[1] = 0; #endif - for (int j = 2; j < k_block_pad - 2; j++) { + for (int j = 2; j < k_block_pad; j++) { #if DEBUG_ACC == 1 acc_score_p[j] = sum_cur; #endif diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 03d8e6369ac673..e17aa4e88a61f5 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -42,7 +42,7 @@ constexpr uint SG_N = 8; constexpr uint BLOCK_WG_M = BLOCK_SG_M * SG_M; constexpr uint BLOCK_WG_N = BLOCK_SG_N * SG_N; constexpr int STRIDE = 16; -constexpr float THRESH = 0.9; +constexpr float THRESH = 1.0; enum class PagedAttentionStage : uint8_t { GENERATE = 0, PREFILL = 1, MIXED = 2, UNKNOWN = 3 }; struct PagedAttentionRuntimeParams : public ImplRuntimeParams { From 4f9ed286e77eeaf7cccc1eda703ef815f6ebe90f Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Tue, 9 Sep 2025 11:06:18 +0800 Subject: [PATCH 13/96] xatten: fix accuacy problem caused by debug --- .../src/graph/impls/cm/include/find_block.hpp | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp index f46f2f2cce2795..e7a3abc4ebc8fd 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp @@ -133,9 +133,9 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p data.row(0) += data.row(i) * inv_sum_v[i]; } desc_sum.set_block_x(desc_sum.get_block_x() + TOKEN_SHARE_MAX); - sum_m_after_add += data.row(0); // the sum type is always half in the reference code of the paper: https://github.com/mit-han-lab/x-attention/blob/fb2ac200a23d20568f7d166ddb5ee247926d2b2b/xattn/src/kernels.py#L248 vector data_half = data.row(0); + sum_m_after_add += data_half; cm_ptr_store((int*)kq_exp_partial_sum, j * (int)sizeof(half), data_half.format()); #if DEBUG_ACC == 1 cm_ptr_store((int*)kq_sum, j * (int)sizeof(half), data_half.format()); @@ -160,7 +160,7 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p auto score_p = (half*)score; half s_0 = score_p[0]; half s_causal = score_p[causal_start_index + m_block]; - half s_sum = s_0; + float s_sum = s_0; if (causal_start_index + m_block) s_sum += s_causal; score_p[0] = -1; score_p[causal_start_index + m_block] = -1; @@ -178,19 +178,27 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p acc_score_p[0] = 0; acc_score_p[1] = 0; #endif - for (int j = 2; j < k_block_pad; j++) { + int j; + for (j = 2; j < k_block_pad; j++) { #if DEBUG_ACC == 1 acc_score_p[j] = sum_cur; #endif if (sum_cur < thresh_act) { block_mask_p[sorted_index_p[j]] = 1; } else { -#if DEBUG_ACC != 1 break; -#endif } sum_cur += sorted_value_p[j]; } +#if DEBUG_ACC == 1 + for (; j < k_block_pad; j++) { + acc_score_p[j] = sum_cur; + sum_cur += sorted_value_p[j]; + } +#endif + + // for (int j = causal_start_index + m_block + 1; j < k_block_pad; j++) + // block_mask_p[j] = 0; #else @@ -204,7 +212,8 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p #if DEBUG_ACC == 1 acc_score_p[0] = 0; #endif - for (int j = 0; j < k_block_pad - 1; j++) { + int j; + for (j = 0; j < k_block_pad - 1; j++) { sum_cur += sorted_value_p[j]; #if DEBUG_ACC == 1 acc_score_p[j + 1] = sum_cur; @@ -213,10 +222,15 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p block_mask_p[sorted_index_p[j]] = 1; } else { block_mask_p[sorted_index_p[j]] = 1; -#if DEBUG_ACC != 1 break; -#endif } } +#if DEBUG_ACC == 1 + for (j = j + 1; j < k_block_pad - 1; j++) { + sum_cur += sorted_value_p[j]; + acc_score_p[j + 1] = sum_cur; + } +#endif + #endif } From d35f4fb7fda7117d4403b062aea8ba69e15eb5ba Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Wed, 10 Sep 2025 11:18:11 +0800 Subject: [PATCH 14/96] use int32 to store float INV_S to align python version accuracy --- .../src/graph/impls/cm/include/estimate.hpp | 16 ++++++++++++++-- .../src/graph/impls/cm/include/find_block.hpp | 4 +++- .../src/graph/impls/cm/paged_attention_gen.cpp | 6 ++++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp index e3cd5acbe78fb9..e34351db241d8d 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp @@ -525,12 +525,18 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { cm_sbarrier(0); matrix acc_half; + union { + float f; + int y; + } i2f; + i2f.y = INV_S; + const float inv_s = i2f.f; #pragma unroll for (uint reg_m = 0; reg_m < REG_M; reg_m++) { #pragma unroll for (int reg_n = 0; reg_n < REG_N; reg_n++) { acc_half.select(reg_m * BLOCK_REG_M, reg_n * BLOCK_REG_N) = - acc.row(reg_m * REG_N + reg_n) * float{INV_S}; + acc.row(reg_m * REG_N + reg_n) * inv_s; } } @@ -1009,12 +1015,18 @@ uint M, uint N, uint K, uint query_stride, uint q_start_strided) { cm_sbarrier(0); matrix acc_half; + union { + float f; + int y; + } i2f; + i2f.y = INV_S; + const float inv_s = i2f.f; #pragma unroll for (uint reg_m = 0; reg_m < REG_M; reg_m++) { #pragma unroll for (int reg_n = 0; reg_n < REG_N; reg_n++) { acc_half.select(reg_m * BLOCK_REG_M, reg_n * BLOCK_REG_N) = - acc.row(reg_m * REG_N + reg_n) * float{INV_S}; + acc.row(reg_m * REG_N + reg_n) * inv_s; } } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp index e7a3abc4ebc8fd..56e32a4c955fbd 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp @@ -184,7 +184,9 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p acc_score_p[j] = sum_cur; #endif if (sum_cur < thresh_act) { - block_mask_p[sorted_index_p[j]] = 1; + auto k_idx = sorted_index_p[j]; + if (k_idx <= causal_start_index + m_block) + block_mask_p[k_idx] = 1; } else { break; } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 3b2bf51c52c75f..5204fbd447e71d 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -642,7 +642,9 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp auto desc = params.typed_desc(); - const float scale_factor = 1.0 / std::sqrt(static_cast(desc->k_head_size)) / STRIDE; + const float scale_factor = 1.0f / std::sqrt(static_cast(desc->k_head_size)) / STRIDE; + int scale_factor_i; + std::memcpy(static_cast(&scale_factor_i), &scale_factor, sizeof(scale_factor)); jit.make("STRIDE", STRIDE); jit.make("HQ", desc->heads_num); @@ -654,7 +656,7 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp jit.make("BLOCK_SG_N", BLOCK_SG_N); jit.make("BLOCK_SIZE", get_xattn_block_size()); jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); - jit.add(make_jit_constant("INV_S", scale_factor)); + jit.add(make_jit_constant("INV_S", scale_factor_i)); jit.make("BLOCK_SHARE_MAX", BLOCK_WG_N); //# loop order walks HQ first and the step is WALK_HQ, 1 means not walk HQ, 2 means walks 2 heads first. Valid value: 1, 2, 4... jit.make("WALK_HQ", desc->heads_num != desc->kv_heads_num ? 2 : 1); From 4e25a4a9614a8007df64f88faf1b1c6a91536c49 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 10 Sep 2025 13:09:49 +0800 Subject: [PATCH 15/96] OV_GPU_XATTN_BLOCK_SIZE and OV_GPU_XATTN_THRESH --- .../intel_gpu/runtime/internal_properties.hpp | 1 + .../include/intel_gpu/runtime/options.inl | 1 + .../src/graph/impls/cm/paged_attention.cpp | 31 ++++++++-------- .../graph/impls/cm/paged_attention_gen.cpp | 35 +++++++++++++------ .../graph/impls/cm/paged_attention_gen.hpp | 9 +++-- 5 files changed, 47 insertions(+), 30 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp index 536eda8e7b06d5..433e5da8c790b9 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp @@ -172,6 +172,7 @@ static constexpr Property asym_dynamic_quantiz static constexpr Property shape_predictor_settings{"GPU_SHAPE_PREDICTOR_SETTINGS"}; static constexpr Property, ov::PropertyMutability::RW> load_dump_raw_binary{"GPU_LOAD_DUMP_RAW_BINARY"}; static constexpr Property could_use_flashattn_v2{"GPU_COULD_USE_FLASHATTN_V2"}; +static constexpr Property xattention_block_size{"GPU_XATTN_BLOCK_SIZE"}; static constexpr Property dynamic_quantization_group_size_max{"GPU_DYNAMIC_QUANTIZATION_GROUP_SIZE_MAX"}; static constexpr Property validate_output_buffer{"VALIDATE_OUTPUT_BUFFER"}; } // namespace ov::intel_gpu diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl index 1546ea2a7c7570..2f923fb8c1637f 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl @@ -55,6 +55,7 @@ OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, asym_dynamic_quantization, fals OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, could_use_flashattn_v2, true, "Enable/Disable SDPA primitive executing with FlashAttenV2 online softmax tricks.") OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, dynamic_quantization_threshold, 64, "Apply dynamic quantization only when batch size is larger than this value in OneDNN") OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, weightless_attr, nullptr, "Used to configure ov::WeightlessCacheAttribute for constants that are not loaded from a .bin file. This typically applies to non-IR inputs (e.g., ORT)") +OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, xattention_block_size, 128, "block size for X-Attention sparse.") OV_CONFIG_DEBUG_GLOBAL_OPTION(ov::intel_gpu, help, false, "Print help message for all config options") OV_CONFIG_DEBUG_GLOBAL_OPTION(ov::intel_gpu, verbose, 0, "Enable logging for debugging purposes. The higher value the more verbose output. 0 - Disabled, 4 - Maximum verbosity") diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index ce9d5b601b5d08..4dffa23a467188 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -44,10 +44,11 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { add_stage(pa_single_token, params); add_stage(pa_single_token_finalization, params); add_stage(pa_multi_token, params); -#if PA_SPARSE_BLOCK_SIZE > 1 - add_stage(xattn_estimate_gemmqk, params); - add_stage(xattn_estimate_find_block, params); -#endif + const size_t xattn_block_size = get_xattn_block_size(params); + if (xattn_block_size > 1) { + add_stage(xattn_estimate_gemmqk, params); + add_stage(xattn_estimate_find_block, params); + } } void update_rt_params(const primitive_inst& instance) override { @@ -89,16 +90,16 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { res_event = {execute_stage(res_event, instance, kv_cache_update)}; if (rt_params->stage == PagedAttentionStage::PREFILL || rt_params->stage == PagedAttentionStage::MIXED) { -#if PA_SPARSE_BLOCK_SIZE > 1 - cldnn::stream& stream = instance.get_network().get_stream(); - stream.finish(); - res_event = {execute_stage(res_event, instance, xattn_estimate_gemmqk)}; - stream.finish(); - // std::cout << "finish xattn_estimate_gemmqk!\n"; - res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; - stream.finish(); - // std::cout << "finish xattn_estimate_find_block!\n"; -#endif + if (has_stage(xattn_estimate_gemmqk)) { + // cldnn::stream& stream = instance.get_network().get_stream(); + // stream.finish(); + res_event = {execute_stage(res_event, instance, xattn_estimate_gemmqk)}; + // stream.finish(); + // std::cout << "finish xattn_estimate_gemmqk!\n"; + res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; + // stream.finish(); + // std::cout << "finish xattn_estimate_find_block!\n"; + } res_event = {execute_stage(res_event, instance, pa_multi_token)}; } else if (rt_params->stage == PagedAttentionStage::GENERATE) { res_event = {execute_stage(res_event, instance, pa_single_token)}; @@ -164,7 +165,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto count_kq_max_wg = static_cast(desc->heads_num * N_kq_groups * q_stride_pad); internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg - const size_t block_size = get_xattn_block_size(); + const size_t block_size = get_xattn_block_size(params); if (block_size > 1) { OPENVINO_ASSERT(block_size % STRIDE == 0, "sparse block_size must be devidable by stride."); const size_t sum_per_n_token_in_block = block_size / STRIDE; // FIXME diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 5204fbd447e71d..efde91377f1ce3 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -186,6 +186,14 @@ size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx) { return paged_attention_past_len; } +const float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx) { + (void) seq_idx; // TODO + + static const char* env = std::getenv("OV_GPU_XATTN_THRESH"); + static const float thresh = env ? std::strtof(env, nullptr) : 0.9; + return thresh; +} + inline void dump_block_indices_begins(const kernel_impl_params& params) { const auto& input_mem = params.memory_deps; const auto mem = input_mem.at(PagedAttentionInputIdx::BLOCK_INDICES_BEGINS); @@ -394,9 +402,11 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block_indices args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block_indices_begins args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence_begins -#if PA_SPARSE_BLOCK_SIZE > 1 - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // sparse_block_mask -#endif + + const size_t block_size = get_xattn_block_size(params); + if (block_size > 1) + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // sparse_block_mask + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len @@ -409,13 +419,15 @@ JitConstants PagedAttentionGeneratorMultiToken::get_jit_constants(const kernel_i const float scale_factor = 1.0 / std::sqrt(static_cast(desc->k_head_size)); auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; + const size_t xattn_block_size = get_xattn_block_size(params); + jit.make("CMFLA_NUM_HEADS", desc->heads_num); jit.make("CMFLA_NUM_KV_HEADS", desc->kv_heads_num); jit.make("CMFLA_HEAD_SIZE", desc->k_head_size); jit.add(make_jit_constant("CMFLA_SCALE_FACTOR", scale_factor)); jit.make("CMFLA_IS_CAUSAL", 1); jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE); - jit.make("SPARSE_BLOCK_SIZE", PA_SPARSE_BLOCK_SIZE); + jit.make("SPARSE_BLOCK_SIZE", xattn_block_size); jit.make("Q_STEP", get_q_step(xe_arch, true)); // for (auto& it : jit) { // std::cout << "\tjit[" << it.name << "] = " << it.value << std::endl; @@ -654,7 +666,7 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp jit.make("SG_N", SG_N); jit.make("BLOCK_SG_M", BLOCK_SG_M); jit.make("BLOCK_SG_N", BLOCK_SG_N); - jit.make("BLOCK_SIZE", get_xattn_block_size()); + jit.make("BLOCK_SIZE", get_xattn_block_size(params)); jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); jit.add(make_jit_constant("INV_S", scale_factor_i)); jit.make("BLOCK_SHARE_MAX", BLOCK_WG_N); @@ -837,7 +849,7 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { const uint wg_k = BLOCK_WG_M; const uint wg_q = BLOCK_WG_N; - const size_t block_size = get_xattn_block_size(); + const size_t block_size = get_xattn_block_size(params); OPENVINO_ASSERT(wg_k % block_size == 0, "wg_k should be multiple of block_size then there is no tails from block_size"); OPENVINO_ASSERT(wg_q % block_size == 0, "wg_q should be multiple of block_size then there is no tails from block_size"); @@ -864,6 +876,8 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { const uint q_block = ceil_div(q_stride, sum_per_n_token_in_block); const uint k_block = ceil_div(k_stride, sum_per_n_token_in_block); + const float xattn_thresh = get_xattn_thresh(params, 0); // TODO: seq_idx + wgs.global = {q_stride_pad / sum_per_n_token_in_block, heads_num, 1}; wgs.local = {1, 1, 1}; @@ -871,10 +885,11 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { std::vector scaler_value = {q_stride, q_stride_pad, k_block_pad, k_block - q_block}; scalars.resize(scaler_value.size() + 1); - if (DEBUG_ENABLED) { // Debug + if (1 || DEBUG_ENABLED) { // Debug std::cout << "XAttentionEstimateFindBlock::get_dispatch_data_func: " - << "k_block: " << k_block << ", q_block: " << q_block - << "q_stride: " << q_stride << ", q_stride_pad: " << q_stride_pad << ", k_block_pad: " << k_block_pad << ", gws: [" << wgs.global[0] << ", " + << "xattn_thresh : " << xattn_thresh + << " k_block: " << k_block << ", q_block: " << q_block + << " q_stride: " << q_stride << ", q_stride_pad: " << q_stride_pad << ", k_block_pad: " << k_block_pad << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } @@ -884,7 +899,7 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { scalars[i].v.u32 = static_cast(scaler_value[i]); } scalars[scaler_value.size()].t = ScalarDescriptor::Types::FLOAT32; // the last is for thresh with f32 dtype - scalars[scaler_value.size()].v.f32 = static_cast(THRESH); + scalars[scaler_value.size()].v.f32 = static_cast(xattn_thresh); }}; } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index e17aa4e88a61f5..6c69173f296527 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -32,8 +32,6 @@ constexpr auto get_pa_build_options() { // BLOCK_SIZE can be 16/32/64/128/256 #define PA_KV_CACHE_BLOCK_SIZE 256 -// sparse attention block size is set to 1 to disable sparse attention support in CM kernels -#define PA_SPARSE_BLOCK_SIZE 128 constexpr uint BLOCK_SG_M = 64; constexpr uint BLOCK_SG_N = 32; @@ -42,7 +40,6 @@ constexpr uint SG_N = 8; constexpr uint BLOCK_WG_M = BLOCK_SG_M * SG_M; constexpr uint BLOCK_WG_N = BLOCK_SG_N * SG_N; constexpr int STRIDE = 16; -constexpr float THRESH = 1.0; enum class PagedAttentionStage : uint8_t { GENERATE = 0, PREFILL = 1, MIXED = 2, UNKNOWN = 3 }; struct PagedAttentionRuntimeParams : public ImplRuntimeParams { @@ -64,8 +61,10 @@ size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx); size_t get_partition_size(); size_t get_partition_num(const size_t kv_len); - -inline size_t get_xattn_block_size() { return PA_SPARSE_BLOCK_SIZE; } +const float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx); +inline size_t get_xattn_block_size(const kernel_impl_params& impl_param) { + return impl_param.get_program().get_config().get_xattention_block_size(); + } class PagedAttentionGeneratorBase : public KernelGenerator { public: From c3c87b7e2ab20beb9e4e5583fb31a4e26e29604c Mon Sep 17 00:00:00 2001 From: "Li, Tingqian" Date: Wed, 10 Sep 2025 17:49:22 +0800 Subject: [PATCH 16/96] fix building error on windows. --- .../src/graph/impls/cm/paged_attention.cpp | 10 +++--- .../graph/impls/cm/paged_attention_gen.cpp | 36 +++++++++---------- .../graph/impls/cm/paged_attention_gen.hpp | 12 +++---- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 4dffa23a467188..3475785c3258b5 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -157,8 +157,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto out_shape = params.output_layouts[0].get_shape(); const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; const size_t q_len = out_shape[0]; - const uint M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` - const uint N = kv_len / STRIDE; + const uint32_t M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` + const uint32_t N = kv_len / STRIDE; const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); @@ -169,9 +169,9 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { if (block_size > 1) { OPENVINO_ASSERT(block_size % STRIDE == 0, "sparse block_size must be devidable by stride."); const size_t sum_per_n_token_in_block = block_size / STRIDE; // FIXME - const uint sum_per_token_in_block = block_size / STRIDE; - const uint k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; - const uint k_block_pad = k_block_in_group * N_kq_groups; + const uint32_t sum_per_token_in_block = block_size / STRIDE; + const uint32_t k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; + const uint32_t k_block_pad = k_block_in_group * N_kq_groups; auto count_kq_exp_partial_sum = static_cast(desc->heads_num * q_stride_pad * k_block_pad); internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f32); // 3: kq_exp_partial_sum diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index efde91377f1ce3..956a35c84ac6fe 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -754,9 +754,9 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { auto out_shape = params.output_layouts[0].get_shape(); const size_t q_len = out_shape[0]; - const uint M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` - const uint N = kv_len / STRIDE; - const uint K = STRIDE * head_size; + const uint32_t M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` + const uint32_t N = kv_len / STRIDE; + const uint32_t K = STRIDE * head_size; auto get_simple_pitch = [](const layout& layout) { size_t pitch = 1; auto dims_padding = layout.get_padded_dims(); @@ -768,8 +768,8 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { } return pitch; }; - const uint query_pitch = get_simple_pitch(querry_layout) * STRIDE; - const uint slice_no = 0, slice = 0; + const uint32_t query_pitch = get_simple_pitch(querry_layout) * STRIDE; + const uint32_t slice_no = 0, slice = 0; const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); @@ -781,7 +781,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { wgs.global = {N_kq_groups * (q_stride_pad / BLOCK_WG_M) * SG_N * WALK_HQ, SG_M, heads_num / WALK_HQ}; wgs.local = {SG_N, SG_M, 1}; - const uint q_start_strided = N - M; + const uint32_t q_start_strided = N - M; OPENVINO_ASSERT(N >= M, "length of key cache must be greater or equal than query"); auto& scalars = kd.params.scalars; @@ -808,7 +808,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { scalars[i].v.s32 = static_cast(scaler_value[i]); } else { scalars[i].t = ScalarDescriptor::Types::UINT32; - scalars[i].v.u32 = static_cast(scaler_value[i]); + scalars[i].v.u32 = static_cast(scaler_value[i]); } } }}; @@ -847,8 +847,8 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { assert(rt_params != nullptr); - const uint wg_k = BLOCK_WG_M; - const uint wg_q = BLOCK_WG_N; + const uint32_t wg_k = BLOCK_WG_M; + const uint32_t wg_q = BLOCK_WG_N; const size_t block_size = get_xattn_block_size(params); OPENVINO_ASSERT(wg_k % block_size == 0, "wg_k should be multiple of block_size then there is no tails from block_size"); OPENVINO_ASSERT(wg_q % block_size == 0, "wg_q should be multiple of block_size then there is no tails from block_size"); @@ -862,19 +862,19 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { auto out_shape = params.output_layouts[0].get_shape(); const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; const size_t q_len = out_shape[0]; - const uint M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` - const uint N = kv_len / STRIDE; - const uint q_stride = M; - const uint k_stride = N; + const uint32_t M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` + const uint32_t N = kv_len / STRIDE; + const uint32_t q_stride = M; + const uint32_t k_stride = N; const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); - const uint sum_per_token_in_block = block_size / STRIDE; - const uint k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; - const uint k_block_pad = k_block_in_group * N_kq_groups; + const uint32_t sum_per_token_in_block = block_size / STRIDE; + const uint32_t k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; + const uint32_t k_block_pad = k_block_in_group * N_kq_groups; - const uint q_block = ceil_div(q_stride, sum_per_n_token_in_block); - const uint k_block = ceil_div(k_stride, sum_per_n_token_in_block); + const uint32_t q_block = ceil_div(q_stride, sum_per_n_token_in_block); + const uint32_t k_block = ceil_div(k_stride, sum_per_n_token_in_block); const float xattn_thresh = get_xattn_thresh(params, 0); // TODO: seq_idx diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 6c69173f296527..fadd095e688180 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -33,12 +33,12 @@ constexpr auto get_pa_build_options() { // BLOCK_SIZE can be 16/32/64/128/256 #define PA_KV_CACHE_BLOCK_SIZE 256 -constexpr uint BLOCK_SG_M = 64; -constexpr uint BLOCK_SG_N = 32; -constexpr uint SG_M = 4; -constexpr uint SG_N = 8; -constexpr uint BLOCK_WG_M = BLOCK_SG_M * SG_M; -constexpr uint BLOCK_WG_N = BLOCK_SG_N * SG_N; +constexpr uint32_t BLOCK_SG_M = 64; +constexpr uint32_t BLOCK_SG_N = 32; +constexpr uint32_t SG_M = 4; +constexpr uint32_t SG_N = 8; +constexpr uint32_t BLOCK_WG_M = BLOCK_SG_M * SG_M; +constexpr uint32_t BLOCK_WG_N = BLOCK_SG_N * SG_N; constexpr int STRIDE = 16; enum class PagedAttentionStage : uint8_t { GENERATE = 0, PREFILL = 1, MIXED = 2, UNKNOWN = 3 }; From 76685f08772f4e7e5b5d70703b155e06520a681f Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 12 Sep 2025 15:09:25 +0800 Subject: [PATCH 17/96] process tail in find_block --- .../src/graph/impls/cm/include/find_block.hpp | 13 +++++++++++-- .../src/graph/impls/cm/paged_attention.cpp | 4 ++-- .../graph/impls/cm/paged_attention_gen.cpp | 19 +++++++++++-------- .../src/graph/impls/cm/xattn_find_block.cm | 19 ++++++++++++++----- 4 files changed, 38 insertions(+), 17 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp index 56e32a4c955fbd..a4c57a19d1e8e7 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp @@ -35,7 +35,7 @@ // kq_max_wg: [b, hq, n_groups, q_stride_pad] // kq_exp_partial_sum: [b, hq, q_stride_pad, k_block_pad] // kq_sum: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] -CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_partial_sum, svmptr_t block_mask, uint q_stride, uint q_stride_pad, uint k_block_pad, float thresh, uint causal_start_index +CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_partial_sum, svmptr_t block_mask, uint q_len, uint q_stride, uint q_stride_pad, uint k_block_pad, float thresh, uint causal_start_index #if DEBUG_ACC == 1 , svmptr_t kq_sum #endif @@ -66,8 +66,17 @@ CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_p int m_start = MYMIN(m, q_stride); int m_end = MYMIN(m_start + TOKEN_SHARE_MAX, q_stride); int valid_m = m_end - m_start; - if (valid_m == 0) return; block_mask += m_block * k_block_pad; + if (valid_m == 0) { + // case for tails: q is not inside mask, aka q % BLOCK_SIZE < STRIDE + if (m * STRIDE < q_len) { + vector one = 1; + for (int j = 0; j < k_block_pad; j += TOKEN_SHARE_MAX) { + cm_ptr_store((int*)block_mask, j, one.format()); + } + } + return; + } lsc::block_2d_desc desc_sum{ kq_exp_partial_sum, (uint)valid_m - 1, (uint)(k_block_pad * sizeof(SOFTMAX_TYPE) - 1), (uint)(k_block_pad * sizeof(SOFTMAX_TYPE) - 1), 0, 0 }; { diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 3475785c3258b5..76b77570103902 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -168,14 +168,14 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { const size_t block_size = get_xattn_block_size(params); if (block_size > 1) { OPENVINO_ASSERT(block_size % STRIDE == 0, "sparse block_size must be devidable by stride."); - const size_t sum_per_n_token_in_block = block_size / STRIDE; // FIXME + const uint32_t q_block_pad = ceil_div(q_len, block_size); const uint32_t sum_per_token_in_block = block_size / STRIDE; const uint32_t k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; const uint32_t k_block_pad = k_block_in_group * N_kq_groups; auto count_kq_exp_partial_sum = static_cast(desc->heads_num * q_stride_pad * k_block_pad); internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f32); // 3: kq_exp_partial_sum - auto count_elements_mask = static_cast(desc->heads_num * (q_stride_pad / sum_per_n_token_in_block) * k_block_pad); + auto count_elements_mask = static_cast(desc->heads_num * q_block_pad * k_block_pad); internal_buffers.emplace_back(count_elements_mask, ov::element::boolean); // 4: sparse_block_mask } } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 956a35c84ac6fe..d5ea6f16516692 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -828,11 +828,13 @@ Arguments XAttentionEstimateFindBlock::get_arguments_desc(const kernel_impl_para args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // block_mask // scalar - args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_stride - args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_stride_pad - args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // k_block_pad - args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // causal_start_index - args.push_back({ArgumentDescriptor::Types::SCALAR, 4}); // thresh + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_stride + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // q_stride_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // q_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 4}); // k_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 5}); // causal_start_index + args.push_back({ArgumentDescriptor::Types::SCALAR, 6}); // thresh return args; } @@ -872,20 +874,21 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { const uint32_t sum_per_token_in_block = block_size / STRIDE; const uint32_t k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; const uint32_t k_block_pad = k_block_in_group * N_kq_groups; + const uint32_t q_block_pad = ceil_div(q_len, block_size); const uint32_t q_block = ceil_div(q_stride, sum_per_n_token_in_block); const uint32_t k_block = ceil_div(k_stride, sum_per_n_token_in_block); const float xattn_thresh = get_xattn_thresh(params, 0); // TODO: seq_idx - wgs.global = {q_stride_pad / sum_per_n_token_in_block, heads_num, 1}; + wgs.global = {q_block_pad, heads_num, 1}; wgs.local = {1, 1, 1}; auto& scalars = kd.params.scalars; - std::vector scaler_value = {q_stride, q_stride_pad, k_block_pad, k_block - q_block}; + std::vector scaler_value = {q_len, q_stride, q_stride_pad, q_block_pad, k_block_pad, k_block - q_block}; scalars.resize(scaler_value.size() + 1); - if (1 || DEBUG_ENABLED) { // Debug + if (DEBUG_ENABLED) { // Debug std::cout << "XAttentionEstimateFindBlock::get_dispatch_data_func: " << "xattn_thresh : " << xattn_thresh << " k_block: " << k_block << ", q_block: " << q_block diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm index 2201d2b088eebb..d3a61789b312ea 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm @@ -24,7 +24,16 @@ namespace KERNEL_NAME { // _GENX_MAIN_ void find_block( extern "C" _GENX_MAIN_ void KERNEL_NAME( - svmptr_t kq_max_wg ATTR, svmptr_t kq_exp_partial_sum ATTR, svmptr_t block_mask ATTR, uint q_stride, uint q_stride_pad, uint k_block_pad, uint causal_start_index, float thresh + svmptr_t kq_max_wg ATTR, + svmptr_t kq_exp_partial_sum ATTR, + svmptr_t block_mask ATTR, + uint q_len, + uint q_stride, + uint q_stride_pad, + uint q_block_pad, + uint k_block_pad, + uint causal_start_index, + float thresh #if DEBUG_ACC == 1 , svmptr_t kq_sum ATTR #endif @@ -34,7 +43,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( // kq_sum: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] // block_mask: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] // [1, 32, 256], [1, 32, 64, 256], [1, 32, 256, 64 * 16], A_sum:[1, 32, 32, 64 * 16] - // global: [q_stride_pad/TOKEN_IN_BLOCK, hq, b] + // global: [q_block_pad, hq, b] const int TOKEN_IN_BLOCK = BLOCK_SIZE / STRIDE; const int TOKEN_SHARE_MAX = BLOCK_SHARE_MAX / TOKEN_IN_BLOCK; uint m = cm_group_id(0); @@ -43,15 +52,15 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( kq_max_wg += (b * HQ + hq) * (k_block_pad / TOKEN_SHARE_MAX) * q_stride_pad * (uint)sizeof(SOFTMAX_TYPE); kq_exp_partial_sum += (b * HQ + hq) * q_stride_pad * k_block_pad * (uint)sizeof(SOFTMAX_TYPE); #if DEBUG_ACC == 1 - kq_sum += (b * HQ + hq) * q_stride_pad / TOKEN_IN_BLOCK * k_block_pad * (uint)sizeof(half); + kq_sum += (b * HQ + hq) * (q_stride_pad / TOKEN_IN_BLOCK) * k_block_pad * (uint)sizeof(half); #endif - block_mask += (b * HQ + hq) * q_stride_pad / TOKEN_IN_BLOCK * k_block_pad; + block_mask += (b * HQ + hq) * q_block_pad * k_block_pad; const uint slm_size = 32 * 16 * sizeof(ushort); cm_slm_init(slm_size); auto slm = cm_slm_alloc(slm_size); - find(slm, m, kq_max_wg, kq_exp_partial_sum, block_mask, q_stride, q_stride_pad, k_block_pad, thresh, causal_start_index + find(slm, m, kq_max_wg, kq_exp_partial_sum, block_mask, q_len, q_stride, q_stride_pad, k_block_pad, thresh, causal_start_index #if DEBUG_ACC == 1 , kq_sum #endif From c5bdcf986a8c16aa2e743b6b0cf94a109d453db9 Mon Sep 17 00:00:00 2001 From: "river.li" Date: Tue, 9 Sep 2025 11:05:16 +0800 Subject: [PATCH 18/96] Fix f16 accuracy issue and optimize 2nd token to improve 5% --- .../src/graph/impls/cm/pa_single_token.cm | 443 +++++++++++++----- .../graph/impls/cm/paged_attention_gen.cpp | 9 +- .../test_cases/paged_attention_gpu_test.cpp | 3 +- 3 files changed, 339 insertions(+), 116 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index 59978e1d615666..bb7ea162349d72 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -40,9 +40,10 @@ #define KV_STEP #define WG_SIZE #define XE_ARCH +#define KV_CACHE_COMPRESSION #endif -#define PARTITION_SUBBLOCK_NUM (KV_PARTITION_SIZE / KV_STEP) +#define KV_PARTITION_STEP_NUM (KV_PARTITION_SIZE / KV_STEP) #define KV_SCALE_ZP_SIZE 0 // 4: scale/zp size @@ -60,14 +61,61 @@ void show(matrix mat) { } printf("]\n"); } + +template +void show_u8(matrix mat) { + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%4d", mat[m][n]); + } + printf("],\n"); + } + printf("]\n"); +} + +template +void show(vector vec) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", vec[n]); + } + printf("]\n"); +} #endif +#define Q_SLICE_NUM (HEADS_NUM / KV_HEADS_NUM) +#if Q_SLICE_NUM > 8 || Q_SLICE_NUM == 1 +#define Q_RepeatCount 1 +#else +#define Q_RepeatCount Q_SLICE_NUM +#endif + +#if KV_CACHE_COMPRESSION + // scale/zp is half-precision, so size = 2 * 2 = 4 bytes + #define KV_SCALE_ZP_SIZE 4 // scale/zp bytes + #define KV_ELEMENT_TYPE uint8_t +#else + #define KV_SCALE_ZP_SIZE 0 // no scale/zp + #define KV_ELEMENT_TYPE half +#endif + +//prepack [K, N] to [K/2, N, 2] layout. +template +inline void prepackAsVNNIWidth2(matrix_ref input, matrix_ref out) { + #pragma unroll + for (int r = 0; r < K/2; r++) { + out.row(r).select(0) = input.row(r*2); + out.row(r).select(1) = input.row(r*2+1); + } +} + extern "C" _GENX_MAIN_ void KERNEL_NAME( // extern "C" _GENX_MAIN_ void cm_sdpa_2nd( half* query [[type("svmptr_t")]], - half* key [[type("svmptr_t")]], - half* value [[type("svmptr_t")]], + KV_ELEMENT_TYPE* key [[type("svmptr_t")]], + KV_ELEMENT_TYPE* value [[type("svmptr_t")]], int* past_lens [[type("svmptr_t")]], int* block_indices [[type("svmptr_t")]], int* block_indices_begins [[type("svmptr_t")]], @@ -91,104 +139,167 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( //# Each WG processes a partition, which is KV_PARTITION_SIZE long and multiple of KV_BLOCK_SIZE. //# KV_BLOCK_SIZE can be 32/64/128/256, etc. const auto seq_idx = cm_global_id(0); - const auto head_num_idx = cm_global_id(1); - const auto kv_head_num_idx = head_num_idx / (HEADS_NUM/KV_HEADS_NUM); - //# const auto wg_local_id = cm_local_id(2); + const auto kv_head_num_idx = cm_global_id(1); + const auto head_num_idx = kv_head_num_idx * (HEADS_NUM/KV_HEADS_NUM); //# KV_PARTITION_SIZE --> EU thread const auto wg_thread_id = cm_global_id(2); const uint kv_partition_num = cm_group_count(2); - const uint partition_idx = cm_group_id(2); - - // # const uint subsequence_idx = gws_subseq_mapping[seq_idx]; - const uint subsequence_idx = seq_idx; + const uint kv_partition_idx = cm_group_id(2); + //# const uint subsequence_idx = gws_subseq_mapping[seq_idx]; + //# const uint subsequence_idx = seq_idx; //# const uint subsequence_begin = subsequence_begins[subsequence_idx]; //# const uint subsequence_end = subsequence_begins[subsequence_idx + 1]; - const uint kv_len = past_lens[subsequence_idx] + 1; - const uint start_block_idx = block_indices_begins[subsequence_idx] + partition_idx * (KV_PARTITION_SIZE / KV_BLOCK_SIZE); + const uint kv_len = past_lens[seq_idx] + 1; + const uint start_block_idx = block_indices_begins[seq_idx] + kv_partition_idx * (KV_PARTITION_SIZE / KV_BLOCK_SIZE); - if(partition_idx * KV_PARTITION_SIZE > kv_len) { - // printf("WG exit: partition_idx=%d, KV_PARTITION_SIZE=%d, kv_len=%d\n", partition_idx, KV_PARTITION_SIZE, kv_len); + if(kv_partition_idx * KV_PARTITION_SIZE > kv_len) { + // printf("WG exit: kv_partition_idx=%d, KV_PARTITION_SIZE=%d, kv_len=%d\n", kv_partition_idx, KV_PARTITION_SIZE, kv_len); return; } const uint total_blocks_num = (kv_len + KV_BLOCK_SIZE - 1) / KV_BLOCK_SIZE; - - //#TODO: int8 compression data - uint kv_pitch = HEAD_SIZE * sizeof(half); - //# fp16 data - //# uint qo_pitch = HEADS_NUM * HEAD_SIZE * sizeof(half); + constexpr uint kv_pitch = HEAD_SIZE * sizeof(KV_ELEMENT_TYPE); //# Load Q into register(as dpas-A tile) - matrix Qmat; - uint qo_offset = (seq_idx*HEADS_NUM*q_len + head_num_idx)*HEAD_SIZE; - for(int k = 0, ri = 0; k < HEAD_SIZE; k += REG_K, ri++) { - cm_svm_block_read((svmptr_t)(query + qo_offset + k), Qmat[ri].format()); - } + const uint qo_offset = (seq_idx*HEADS_NUM*q_len + head_num_idx)*HEAD_SIZE; - //if(head_num_idx==0 && partition_idx==1) { - // printf("Qmat loaded, wg_thread_id=%d\n", wg_thread_id); + #if Q_RepeatCount != 1 + matrix Qmat = 0; + cm_svm_block_read((svmptr_t)(query + qo_offset), Qmat.format()); + #else + matrix Qmat = 0; + cm_svm_block_read((svmptr_t)(query + qo_offset), Qmat.format()); + #endif + + //if(kv_head_num_idx==0 && kv_partition_idx == 0) { + // printf("Qmat loaded, kv_head_num_idx=%d\n", kv_head_num_idx); // show(Qmat); //} - const uint per_kv_block_element_num = KV_BLOCK_SIZE * KV_HEADS_NUM * (HEAD_SIZE + KV_SCALE_ZP_SIZE / sizeof(half)); // 4: scale/zp + constexpr uint per_kv_block_element_num = KV_BLOCK_SIZE * KV_HEADS_NUM * (HEAD_SIZE + KV_SCALE_ZP_SIZE / sizeof(KV_ELEMENT_TYPE)); // 4 bytes: scale/zp uint block_num = KV_PARTITION_SIZE / KV_BLOCK_SIZE; - uint leftover_aligned_size = 0; uint leftover_size = 0; - if(partition_idx == kv_partition_num - 1) { - leftover_size = (kv_len - KV_PARTITION_SIZE * partition_idx) % KV_PARTITION_SIZE; - leftover_aligned_size = KV_STEP * ((leftover_size + KV_STEP - 1) / KV_STEP); // round up to KV_STEP + if(kv_partition_idx == kv_partition_num - 1) { + leftover_size = (kv_len - KV_PARTITION_SIZE * kv_partition_idx) % KV_PARTITION_SIZE; } if(block_num > total_blocks_num - start_block_idx) { block_num = total_blocks_num - start_block_idx; } //# rS = Q @ Kt - //# PARTITION_SUBBLOCK_NUM * [REG_M, REG_K] * [REG_K, REG_N] = PARTITION_SUBBLOCK_NUM * [REG_M, REG_N] - matrix rS = 0; - // # each WI can process multiple blocks + //# KV_PARTITION_STEP_NUM * [REG_M, REG_K] * [REG_K, REG_N] = KV_PARTITION_STEP_NUM * [REG_M, REG_N] + #if Q_RepeatCount != 1 + matrix rS = 0; + #else + matrix rS = 0; + #endif + // # Each SG can process multiple blocks + #pragma unroll for(uint block_idx = 0, ki = 0; block_idx < block_num; block_idx++) { uint blk_indices = block_indices[start_block_idx + block_idx]; uint kv_base_offset = blk_indices * per_kv_block_element_num + kv_head_num_idx * (per_kv_block_element_num / KV_HEADS_NUM); uint kv_scale_zp_offset = kv_base_offset + KV_BLOCK_SIZE * HEAD_SIZE; // scale/zp offset - // printf("seq_idx = %d, head_num_idx = %d, partition_idx = %d, start_block_idx = %d, block_idx = %d, blk_indices = %d, KV_PARTITION_SIZE = %d, KV_BLOCK_SIZE = %d, total_blocks_num = %d, seq_len = %d, kv_base_offset = %d\n", - // seq_idx, head_num_idx, partition_idx, start_block_idx, block_idx, blk_indices, KV_PARTITION_SIZE, KV_BLOCK_SIZE, total_blocks_num, seq_len, kv_base_offset); + // printf("seq_idx = %d, head_num_idx = %d, kv_partition_idx = %d, start_block_idx = %d, block_idx = %d, blk_indices = %d, KV_PARTITION_SIZE = %d, KV_BLOCK_SIZE = %d, total_blocks_num = %d, kv_pitch = %d, kv_base_offset = %d\n", + // seq_idx, head_num_idx, kv_partition_idx, start_block_idx, block_idx, blk_indices, KV_PARTITION_SIZE, KV_BLOCK_SIZE, total_blocks_num, kv_pitch, kv_base_offset); #if USE_LSC_BLOCK_2D_DESC - //# vector load cannot be used for block_2d_desc - //# note: candidate template ignored: deduced type 'details::Block2DRefTy' (aka 'vector_ref') of 1st parameter - //# b2dK reinterpret as 32bit(DWORD) for transposed load(combined with VNNI) - lsc::block_2d_desc b2dK(reinterpret_cast(key + kv_base_offset), KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(half) - 1, kv_pitch - 1, 0, 0); - //printf("b2dK: kv_base_offset = %d, KV_BLOCK_SIZE = %d, HEAD_SIZE = %d, kv_pitch = %d, blk_indices = %d, block_idx = %d, start_block_idx = %d\n", - // kv_base_offset, KV_BLOCK_SIZE, HEAD_SIZE, kv_pitch, blk_indices, block_idx, start_block_idx); + #if KV_CACHE_COMPRESSION + // Transpose only support dword and qwork + lsc::block_2d_desc b2dK(reinterpret_cast(key + kv_base_offset), KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + #else + lsc::block_2d_desc b2dK(reinterpret_cast(key + kv_base_offset), KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(half) - 1, kv_pitch - 1, 0, 0); + #endif #else uint kv_offset = kv_base_offset; uint kv_stride = HEAD_SIZE; uint kv_x0 = 0, kv_y0 = 0; - uint kv_x1 = HEAD_SIZE*sizeof(half); + uint kv_x1 = HEAD_SIZE*sizeof(KV_ELEMENT_TYPE); uint kv_y1 = KV_BLOCK_SIZE; #endif uint kv_pos_end = KV_BLOCK_SIZE; - if(block_idx == block_num - 1 && leftover_aligned_size > 0) { + if(block_idx == block_num - 1 && leftover_size > 0) { kv_pos_end = leftover_size % KV_BLOCK_SIZE; + if(kv_pos_end == 0) kv_pos_end = KV_BLOCK_SIZE; } + + #if KV_CACHE_COMPRESSION + // load scale/zp + vector scale_vec; + vector zp_vec; + cm_svm_block_read(reinterpret_cast(key + kv_scale_zp_offset), scale_vec); + cm_svm_block_read(reinterpret_cast(key + kv_scale_zp_offset + KV_BLOCK_SIZE * sizeof(half)), zp_vec); + if(kv_pos_end < KV_BLOCK_SIZE) { + // fill leftover with last valid scale/zp + #pragma unroll + for(int i = kv_pos_end; i < KV_BLOCK_SIZE; i++) { + scale_vec[i] = 0.0; + zp_vec[i] = 0.0; + } + } + #endif + for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += KV_STEP, ki++) { - auto rSvec = rS[ki].format(); - uint kv_offset_y = kv_pos; + // auto rSvec = rS[ki].format(); + // uint kv_offset_y = kv_pos; + + #if KV_CACHE_COMPRESSION + vector temp_scale, temp_zp; + temp_scale.select(0) = scale_vec.select(kv_pos); + temp_scale.select(1) = scale_vec.select(kv_pos); + temp_zp.select(0) = zp_vec.select(kv_pos); + temp_zp.select(1) = zp_vec.select(kv_pos); + #endif #pragma unroll + #if KV_CACHE_COMPRESSION + for(int k = 0, ri = 0; k < HEAD_SIZE/4; k += REG_K/4, ri ++ ) { + #else for(int k = 0, ri = 0; k < HEAD_SIZE/2; k += REG_K/2, ri ++ ) { - matrix Kt; + #endif + matrix Kt = 0; #if USE_LSC_BLOCK_2D_DESC //# Load Kt into register & pack as VNNI(as dpas-B tile) //# DWORD transposed load == (transposed + VNNI) load b2dK.set_block_x(k); - cm_load(Kt.format(), b2dK.set_block_y(kv_offset_y)); + + #if KV_CACHE_COMPRESSION + // dequantize + matrix Kt_quant_temp, Kt_quant; + cm_load(Kt_quant_temp.format(), b2dK.set_block_y(kv_pos)); + auto quant_src = Kt_quant_temp.format(); + auto quant_dst = Kt_quant.format(); + + #pragma unroll + for(int r = 0; r < REG_K / 2; r += 2) { + quant_dst.row(r ) = quant_src.select<2,1,8,2>(r,0); + quant_dst.row(r+1) = quant_src.select<2,1,8,2>(r,1); + } + + #if DEBUG_ENABLE + printf("Kt_quant_temp: k = %d\n", k); + show_u8(Kt_quant_temp.format()); + printf("Kt_quant_vnni: k = %d\n", k); + show_u8(Kt_quant.format()); + #endif + + #pragma unroll + for(int r = 0; r < REG_K; r++) { + Kt[r] = Kt_quant[r] - temp_zp.format()[r%2]; //vector - vector + Kt[r] = cm_mul(Kt[r], temp_scale.format()[r%2]); // vector * vector + } + #else + cm_load(Kt.format(), b2dK.set_block_y(kv_pos)); + #endif + //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { + // printf("Kt: k = %d\n", k); + // show(Kt.format()); + //} #else matrix temp; - uint cur_kv_offset = kv_offset + kv_offset_y * kv_stride + k * 2;// uint --> half + uint cur_kv_offset = kv_offset + kv_pos * kv_stride + k * 2;// uint --> half #pragma unroll for(int kk = 0; kk < REG_N; kk++) { cm_svm_block_read((svmptr_t)(key + cur_kv_offset + kk * kv_stride), temp[kk].format()); @@ -200,34 +311,62 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( Transpose_8x8(temp.select<8,1,8,1>(8,0), Kt.format().select<8,1,8,1>(0,8)); #endif #endif - rSvec = cm_dpas( + #if Q_RepeatCount != 1 + matrix Qmat_data = Qmat.select(0, ri*REG_K); + matrix rS_data = 0; + for(int qi = 0; qi < Q_SLICE_NUM; qi ++) { + Qmat_data[qi] = Qmat[qi].format()[ri]; + } + rS_data = cm_dpas( + rS_data.format(), + Kt.format(), + Qmat_data.format()); + rS.select(0, ki*REG_N) += rS_data; + + //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { + // show(rS_data); + //} + #else + #pragma unroll + for(int qi = 0; qi < Q_SLICE_NUM; qi ++) { + auto Qmat_slice = Qmat[qi].format(); + auto rSvec = rS[qi].format()[ki].format(); + rSvec = cm_dpas( rSvec, Kt.format(), - Qmat[ri].format()); + Qmat_slice[ri].format()); + } + #endif } } } - // if(head_num_idx==0 && partition_idx==1) { - // printf("rS:\n"); - // show(rS); - // } + //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { + // printf("rS:\n"); + // show(rS); + //} // online softmax - float cur_sum = 0.0f; - float cur_lse = 0.0f; + vector cur_sum = 0.0f; + vector cur_lse = 0.0f; #if XE_ARCH==1 - matrix Pmat = 0; + matrix Pmat = 0; #else - matrix Pmat = 0; + #if Q_RepeatCount != 1 + matrix Pmat = 0; + #else + matrix Pmat = 0; + #endif #endif - { - rS = cm_mul(rS, (float)SCALE_FACTOR); // convert scale_factor into (float), or it will be promoted to double + #pragma unroll + for(int qi = 0; qi < Q_SLICE_NUM; qi++) { + auto rS_slice = rS[qi].format(); + rS_slice = cm_mul(rS_slice, (float)SCALE_FACTOR); // convert scale_factor into (float), or it will be promoted to double - // printf("leftover_size = %d, leftover_aligned_size = %d, XE_ARCH = %d, PARTITION_SUBBLOCK_NUM * REG_N = %d\n", leftover_size, leftover_aligned_size, XE_ARCH, PARTITION_SUBBLOCK_NUM * REG_N); + // printf("leftover_size = %d, leftover_aligned_size = %d, XE_ARCH = %d, KV_PARTITION_STEP_NUM * REG_N = %d\n", leftover_size, leftover_aligned_size, XE_ARCH, KV_PARTITION_STEP_NUM * REG_N); if(leftover_size > 0) { - auto Svec = rS.format(); - for(int i = leftover_size; i < PARTITION_SUBBLOCK_NUM * REG_N; i++){ + auto Svec = rS_slice.format(); + for(int i = leftover_size; i < KV_PARTITION_STEP_NUM * REG_N; i++){ Svec[i] = -3e38f; } } @@ -235,47 +374,46 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( // compute lse constexpr float log2e = 1.4426950408889634f; constexpr float loge2 = 0.6931471805599453f; - vector rS_exp = cm_exp(rS.format()*log2e); - float cur_lse_0 = cm_sum(rS_exp); - - //if(head_num_idx==0 && partition_idx==1) { - // uint lse_offset = seq_idx * HEADS_NUM * kv_partition_num + head_num_idx * kv_partition_num + wg_thread_id; - // printf("LSE[%d]: %f\n", lse_offset, cur_lse); - // printf("rS_exp:\n"); - // show(rS_exp.format()); - //} + vector rS_exp = cm_exp(rS_slice.format()*log2e); // compute row_max - auto rSv = rS.format(); + auto rSv = rS_slice.format(); float row_max = rSv[0]; + // It is performance hotspot for u8, must add unroll + #if KV_CACHE_COMPRESSION + #pragma unroll + #endif for(int r = 1; r < rSv.n_elems(); r++) row_max = cm_max(row_max, rSv[r]); - // compute P = exp(rS - row_max) + // compute P = exp(rS_slice - row_max) #if XE_ARCH==1 - Pmat= cm_exp((rS.format() - row_max)*log2e); + Pmat[qi].format() = cm_exp((rS_slice.format() - row_max)*log2e); #else - Pmat= cm_exp((rS - row_max)*log2e); + Pmat[qi].format() = cm_exp((rS_slice - row_max)*log2e); #endif - vector rS_exp_temp = cm_exp((rS.format() - row_max)*log2e); - cur_lse = cm_sum(rS_exp_temp.format()); - cur_lse = cm_log(cur_lse) * loge2 + row_max; // log2(sum(exp(x))) = log2e * log(sum(exp(x))) - //float cur_lse_1 = cm_exp(cur_lse * log2e); - //printf("row_max= %f, cur_lse =%f, cur_lse_0 = %f, cur_lse_1 = %f\n", row_max, cur_lse, cur_lse_0, cur_lse_1); + vector rS_exp_temp = cm_exp((rS_slice.format() - row_max)*log2e); + cur_lse[qi] = cm_sum(rS_exp_temp.format()); + cur_lse[qi] = cm_log(cur_lse[qi]) * loge2 + row_max; // log2(sum(exp(x))) = log2e * log(sum(exp(x))) // compute row sum of P - auto rPv = Pmat.format(); - cur_sum = cm_sum(rPv[0]); + auto rPv = Pmat[qi].format(); + cur_sum[qi] = cm_sum(rPv[0]); } - // if(wg_thread_id==kv_partition_num - 1 && head_num_idx == 0) { + //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { // printf("Pmat:\n"); // show(Pmat); //} //# rO = P * V - matrix Omat = 0; + #if Q_RepeatCount != 1 + matrix Omat = 0; + #else + matrix Omat = 0; + #endif + #pragma unroll for(uint block_idx = 0, ki = 0; block_idx < block_num; block_idx++) { uint blk_indices = block_indices[start_block_idx + block_idx]; uint kv_base_offset = blk_indices * per_kv_block_element_num + kv_head_num_idx * (per_kv_block_element_num / KV_HEADS_NUM); @@ -284,7 +422,11 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #if USE_LSC_BLOCK_2D_DESC //# vector load cannot be used for block_2d_desc //# note: candidate template ignored: deduced type 'details::Block2DRefTy' (aka 'vector_ref') of 1st parameter + #if KV_CACHE_COMPRESSION + lsc::block_2d_desc b2dV(value + kv_base_offset, KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + #else lsc::block_2d_desc b2dV(value + kv_base_offset, KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(half) - 1, kv_pitch - 1, 0, 0); + #endif #else uint kv_offset = kv_base_offset; uint kv_stride = HEAD_SIZE; @@ -293,22 +435,80 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint kv_y1 = KV_BLOCK_SIZE; #endif + //if(kv_partition_idx==kv_partition_num - 1 && head_num_idx == HEADS_NUM - 1) { + // printf("leftover_size = %d, leftover_aligned_size = %d, XE_ARCH = %d, KV_BLOCK_SIZE = %d\n", leftover_size, leftover_aligned_size, XE_ARCH, KV_BLOCK_SIZE); + //} uint kv_pos_end = KV_BLOCK_SIZE; - if(block_idx == block_num - 1 && leftover_aligned_size > 0) { + if(block_idx == block_num - 1 && leftover_size > 0) { kv_pos_end = leftover_size % KV_BLOCK_SIZE; + if(kv_pos_end == 0) kv_pos_end = KV_BLOCK_SIZE; + } + + #if KV_CACHE_COMPRESSION + // load scale/zp + vector scale_vec; + vector zp_vec; + cm_svm_block_read(reinterpret_cast(value + kv_scale_zp_offset), scale_vec); + cm_svm_block_read(reinterpret_cast(value + kv_scale_zp_offset + KV_BLOCK_SIZE * sizeof(half)), zp_vec); + if(kv_pos_end < KV_BLOCK_SIZE) { + // fill leftover with last valid scale/zp + for(int i = kv_pos_end; i < KV_BLOCK_SIZE; i++) { + scale_vec[i] = 0.0; + zp_vec[i] = 0.0; + } } - for(int kv_pos =0; kv_pos < kv_pos_end; kv_pos += REG_K, ki++) { - uint kv_offset_y = kv_pos; + #endif + #pragma unroll + for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += REG_K, ki++) { + // uint kv_offset_y = kv_pos; + #if KV_CACHE_COMPRESSION + vector temp_scale = scale_vec.select(kv_pos); + vector temp_zp = zp_vec.select(kv_pos); + #endif #pragma unroll for(int k = 0, ri = 0; k < HEAD_SIZE; k += REG_N, ri ++ ) { // Load V into register & pack as VNNI(as dpas-B tile) + matrix VmatNormal; matrix Vmat; #if USE_LSC_BLOCK_2D_DESC b2dV.set_block_x(k); - cm_load(Vmat[0].format(), b2dV.set_block_y(kv_offset_y)); + #if KV_CACHE_COMPRESSION + // dequantize + matrix Vt_quant; + cm_load(Vt_quant.format(), b2dV.set_block_y(kv_pos)); + + #if DEBUG_ENABLE + //printf("Vt_quant: k = %d\n", k); + //show_u8(Vt_quant.format()); + //show(temp_scale); + //show(temp_zp); + //printf("\n"); + #endif + + #pragma unroll + for(int r = 0; r < REG_K; r++) { + VmatNormal[r] = Vt_quant[r] - temp_zp[r]; // vector - scalar + VmatNormal[r] = cm_mul(VmatNormal[r], temp_scale[r]); // vector * scalar + } + // show(VmatNormal.format()); + + if(kv_pos_end - kv_pos < KV_STEP) { + #pragma unroll + for(int r = kv_pos_end; r()); + #else + cm_load(Vmat[0].format(), b2dV.set_block_y(kv_pos)); + #endif + #if DEBUG_ENABLE + //printf("Vmat: k = %d\n", k); + //show(Vmat.format()); + #endif #else matrix temp; - uint cur_kv_offset = kv_offset + kv_offset_y * kv_stride + k; + uint cur_kv_offset = kv_offset + kv_pos * kv_stride + k; #pragma unroll for(int kk = 0; kk < REG_K; kk++) { cm_svm_block_read((svmptr_t)(value + cur_kv_offset + kk * kv_stride), temp[kk].format()); @@ -317,35 +517,54 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( Vref.select(0, 0) = temp.select(0, 0); Vref.select(0, 1) = temp.select(1, 0); #endif - Omat[ri] = cm_dpas( - Omat[ri], + #if Q_RepeatCount != 1 + matrix Pmat_data = Pmat.select(0, ki*REG_K); + matrix Omat_data = 0; + Omat_data = cm_dpas( + Omat_data.format(), + Vmat[0].format(), + Pmat_data.format()); + Omat.select(0, ri*REG_N) += Omat_data; + #else + for(int qi = 0; qi < Q_SLICE_NUM; qi ++) { + auto Pmat_slice = Pmat[qi].format(); + auto Omat_slice = Omat[qi].format(); + Omat_slice[ri] = cm_dpas( + Omat_slice[ri], Vmat[0].format(), - Pmat[ki].format()); + Pmat_slice[ki].format()); + } + #endif + //if(kv_partition_idx==kv_partition_num - 1 && head_num_idx == 27) { + // printf("Omat[%d][%d]:\n",kv_pos, k); + // show(Omat); + //} } } } - // if(wg_thread_id==kv_partition_num - 1) { + //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { // printf("Omat:\n"); // show(Omat); - // } + //} //# save Output - matrix cur_O; - uint o_offset = seq_idx * kv_partition_num * HEADS_NUM * HEAD_SIZE + kv_partition_num * head_num_idx * HEAD_SIZE + wg_thread_id * HEAD_SIZE; - float div_cur_sum = 1.0/cur_sum; - #pragma unroll - for(int k = 0, ri=0; k < HEAD_SIZE; k += REG_N, ri++) { - auto cO = Omat[ri].format(); - #if XE_ARCH==1 - cur_O= cm_mul(cO, div_cur_sum); - #else - cur_O= cm_div_ieee(cO, cur_sum); - #endif - cm_svm_block_write((svmptr_t)(output + o_offset + k),cur_O.format()); + for (int qi = 0; qi < Q_SLICE_NUM; qi++) { + matrix cur_O_f32; + uint o_offset = seq_idx * kv_partition_num * KV_HEADS_NUM * HEAD_SIZE + kv_partition_num * (head_num_idx + qi) * HEAD_SIZE + wg_thread_id * HEAD_SIZE; + float div_cur_sum = 1.0/cur_sum[qi]; + auto Omat_slice = Omat[qi].format(); + #pragma unroll + for(int k = 0, ri=0; k < HEAD_SIZE; k += REG_N, ri++) { + auto cO = Omat_slice[ri].format(); + #if XE_ARCH==1 + cur_O_f32= cm_mul(cO, div_cur_sum); + #else + cur_O_f32= cm_div_ieee(cO, cur_sum[qi]); + #endif + cm_svm_block_write((svmptr_t)(output + o_offset + k),cur_O_f32.format()); + } + uint lse_offset = seq_idx * KV_HEADS_NUM * kv_partition_num + (head_num_idx + qi) * kv_partition_num + wg_thread_id; + lse[lse_offset] = cur_lse[qi]; } - uint lse_offset = seq_idx * HEADS_NUM * kv_partition_num + head_num_idx * kv_partition_num + wg_thread_id; - lse[lse_offset] = cur_lse; - - // printf("LSE[%d]: %f\n", lse_offset, cur_lse); } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index d5ea6f16516692..590757a934dd13 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -509,6 +509,8 @@ JitConstants PagedAttentionGeneratorSingleToken::get_jit_constants(const kernel_ jit.make("KV_HEADS_NUM", desc->kv_heads_num); jit.make("Q_STEP", get_q_step(xe_arch, true)); + jit.make("KV_CACHE_COMPRESSION", 0); + return jit; } @@ -544,13 +546,14 @@ DispatchDataFunc PagedAttentionGeneratorSingleToken::get_dispatch_data_func() co auto& wgs = kd.params.workGroups; const auto desc = params.typed_desc(); auto rtp = static_cast(rt_params); - assert(rt_params != nullptr); const size_t batch = params.input_layouts[0].get_partial_shape()[0].get_length(); const size_t heads_num = desc->heads_num; + const size_t kv_heads_num = desc->kv_heads_num; const size_t partition_num = rtp->num_of_partitions; // get_partition_num(rtp->max_context_len); - wgs.global = {batch, heads_num, partition_num}; + + wgs.global = {batch, kv_heads_num, partition_num}; wgs.local = {1, 1, 1}; // generate stage: q_len=1 @@ -626,7 +629,7 @@ DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_da scalars.resize(scaler_value.size()); if (DEBUG_ENABLED) { // Debug - std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " + std::cout << "PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_data_func: " << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 497d6d89b42a22..0657488e706f4e 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -952,7 +952,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam { public: random_generator rg; cldnn::engine& engine = get_test_engine(); - float tolerance = 2e-2; + float tolerance = 2e-3; void SetUp() override { rg.set_seed(GET_SUITE_NAME); @@ -1258,6 +1258,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing: paged_attention_test_params{ {{1024, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long paged_attention_test_params{ {{1, 31}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 32}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 1023}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 127}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 129}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token From 95a2da1ada8c87617febe642f475e8bc2eb07afe Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 15 Sep 2025 09:21:05 +0800 Subject: [PATCH 19/96] fix waring_as_error on CI Windows. --- .../src/graph/impls/cm/paged_attention.cpp | 8 ++++---- .../src/graph/impls/cm/paged_attention_gen.cpp | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 76b77570103902..08c02ec405cf21 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -157,8 +157,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto out_shape = params.output_layouts[0].get_shape(); const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; const size_t q_len = out_shape[0]; - const uint32_t M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` - const uint32_t N = kv_len / STRIDE; + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t N = static_cast(kv_len / STRIDE); const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); @@ -169,8 +169,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { if (block_size > 1) { OPENVINO_ASSERT(block_size % STRIDE == 0, "sparse block_size must be devidable by stride."); const uint32_t q_block_pad = ceil_div(q_len, block_size); - const uint32_t sum_per_token_in_block = block_size / STRIDE; - const uint32_t k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; + const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); + const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); const uint32_t k_block_pad = k_block_in_group * N_kq_groups; auto count_kq_exp_partial_sum = static_cast(desc->heads_num * q_stride_pad * k_block_pad); internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f32); // 3: kq_exp_partial_sum diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 590757a934dd13..5f505dc3e329f1 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -757,8 +757,8 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { auto out_shape = params.output_layouts[0].get_shape(); const size_t q_len = out_shape[0]; - const uint32_t M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` - const uint32_t N = kv_len / STRIDE; + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t N = static_cast(kv_len / STRIDE); const uint32_t K = STRIDE * head_size; auto get_simple_pitch = [](const layout& layout) { size_t pitch = 1; @@ -858,7 +858,7 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { OPENVINO_ASSERT(wg_k % block_size == 0, "wg_k should be multiple of block_size then there is no tails from block_size"); OPENVINO_ASSERT(wg_q % block_size == 0, "wg_q should be multiple of block_size then there is no tails from block_size"); - const size_t sum_per_n_token_in_block = block_size / STRIDE; + const size_t sum_per_n_token_in_block = static_cast(block_size / STRIDE); // const size_t batch = params.input_layouts[PagedAttentionInputIdx::QUERY].get_partial_shape()[0].get_length(); const size_t heads_num = desc->heads_num; @@ -867,15 +867,15 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { auto out_shape = params.output_layouts[0].get_shape(); const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; const size_t q_len = out_shape[0]; - const uint32_t M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` - const uint32_t N = kv_len / STRIDE; + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t N = static_cast(kv_len / STRIDE); const uint32_t q_stride = M; const uint32_t k_stride = N; const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); - const uint32_t sum_per_token_in_block = block_size / STRIDE; - const uint32_t k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; + const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); + const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); const uint32_t k_block_pad = k_block_in_group * N_kq_groups; const uint32_t q_block_pad = ceil_div(q_len, block_size); From 36bee720dfc692a5a0c10451eba880849c8f9ff9 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 15 Sep 2025 13:27:58 +0800 Subject: [PATCH 20/96] dump block mask with DUMP_XATTN_BLOCK_MASK for debug --- .../src/graph/impls/cm/paged_attention.cpp | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 08c02ec405cf21..9e07710643671c 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -20,6 +20,11 @@ #include "paged_attention_inst.h" #include "primitive_inst.h" +#define DUMP_XATTN_BLOCK_MASK 0 +#if DUMP_XATTN_BLOCK_MASK +#include "openvino/util/file_util.hpp" +#endif + namespace ov::intel_gpu::cm { class PagedAttentionCmImpl : public PrimitiveImplCM { @@ -97,8 +102,28 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { // stream.finish(); // std::cout << "finish xattn_estimate_gemmqk!\n"; res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; - // stream.finish(); - // std::cout << "finish xattn_estimate_find_block!\n"; +#if DUMP_XATTN_BLOCK_MASK + { + cldnn::stream& stream = instance.get_network().get_stream(); + stream.finish(); + static uint32_t pa_id = 0; + std::cout << "finish xattn_estimate_find_block!\n"; + auto output_mem = instance.get_intermediates_memories()[4]; + mem_lock lock(output_mem, stream); + auto& layout = output_mem->get_layout(); + std::string data_type = ov::element::Type(layout.data_type).get_type_name(); + std::string format = layout.format.to_string(); + std::string tensor; + auto dims = layout.get_dims(); + for (size_t r = 0 ; r < layout.get_rank() ; r++) { + tensor += ("_" + to_string(dims[r])); + } + // std::string filename = "PA" + std::to_string(pa_id) + "__" + data_type + "_" + tensor + "__" + format + ".bin"; + std::string filename = "PA" + std::to_string(pa_id) + ".bin"; + ov::util::save_binary(filename, lock.data(), output_mem->size()); + pa_id++; + } +#endif } res_event = {execute_stage(res_event, instance, pa_multi_token)}; } else if (rt_params->stage == PagedAttentionStage::GENERATE) { From 4fa97bec52717376df078df63114467ee39e7daf Mon Sep 17 00:00:00 2001 From: "river.li" Date: Sun, 14 Sep 2025 16:35:01 +0800 Subject: [PATCH 21/96] Support kv cache u8 precision --- .../convert_pagedattn_inputs.cpp | 2 +- .../graph/impls/cm/include/cm_sdpa_common.hpp | 880 +++++++++--------- .../graph/impls/cm/pa_kv_cache_update_ref.cm | 40 +- .../src/graph/impls/cm/pa_multi_token.cm | 101 +- .../src/graph/impls/cm/pa_single_token.cm | 7 +- .../src/graph/impls/cm/paged_attention.hpp | 2 +- .../graph/impls/cm/paged_attention_gen.cpp | 45 +- .../impls/ocl_v2/sdpa/paged_attention_opt.hpp | 2 +- 8 files changed, 574 insertions(+), 505 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp index 636c20ed609614..7935d905a3b638 100644 --- a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp @@ -107,7 +107,7 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co value_cache->set_element_type(value_cache_precision); bool status = false; if (pa_op->get_rt_info().count("num_k_heads") && pa_op->get_rt_info().count("k_head_size") && - pa_op->get_rt_info().count("num_v_heads") && pa_op->get_rt_info().count("num_v_heads")) { + pa_op->get_rt_info().count("num_v_heads") && pa_op->get_rt_info().count("v_head_size")) { const auto key_cache_shape = init_cache_shape(pa_op->get_rt_info()["num_k_heads"].as(), pa_op->get_rt_info()["k_head_size"].as(), m_config.keyCacheBlockSize, diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp index b4f679d198f4a2..c8a3ed83921bf8 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp @@ -18,6 +18,8 @@ //# CM-compiler is C++17 static_assert(__cplusplus >= 201703L); +//# static_assert(__cplusplus >= 202002L); +//# static_assert(__cplusplus >= 202302L); #define SystolicDepth 8 #define RepeatCount 8 @@ -38,19 +40,29 @@ static_assert(q_step == 16 || q_step == 8); static_assert(kv_step == 16); static_assert(CM_HAS_DPAS); +#define DEBUG_SHOW 1 +#if !DEBUG_SHOW template -void show(const matrix mat) { +void show(const matrix mat, bool isfloat=true) { +} +#else +template +void show(const matrix mat, bool isfloat=true) { printf("Matrix [%d, %d]:\n", M, N); for(int m = 0; m < M; m ++) { printf("\t["); for(int n = 0; n < N; n ++) { - printf("%8.4f,", mat[m][n]); + if (isfloat) + printf("%8.4f,", mat[m][n]); + else + printf("%8d,", mat[m][n]); + } printf("],\n"); } printf("]\n"); } - +#endif template CM_INLINE void Transpose_16x16(matrix_ref in, matrix_ref out) { @@ -218,6 +230,9 @@ inline matrix ugemm_KQ(uint slm_K, matrix_ref Kmat; cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); + // if (cm_local_id(2) == 3 && cm_group_id(2) == 0) { + // show(Kmat.format()); + // } #pragma unroll for(int k = 0; k < num_K; k++) St2.row(k) = cm_dpas(0, Qt[0].format(), Kmat[k].format()); @@ -242,7 +257,9 @@ inline void ugemm_PV0(uint slm_V, matrix_ref P, matrix_ref Vmat; cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); - + // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { + // show(Vmat.format()); + // } #pragma unroll for(int p = 0; p < num_P_tiles; p++) { rO[ri + p] = cm_dpas( @@ -262,10 +279,11 @@ inline void ugemm_PV1(uint slm_V, matrix_ref P, vector_ref Vmat; - cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); - //# compensate cur_O - // matrix rO; + cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); + // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { + // show(Vmat.format()); + // } #pragma unroll for(int p = 0; p < num_P_tiles; p++) { auto cO = rO[ri + p].format(); @@ -276,7 +294,6 @@ inline void ugemm_PV1(uint slm_V, matrix_ref P, vector_ref()); - //# show(cur_O.format()); return; #pragma unroll for(int p = 0; p < num_P_tiles; p++) { rO[ri + p] = cm_dpas( @@ -378,8 +395,6 @@ vector online_softmax_update(matrix_ref St, vector_r } #endif - - //=============================================================================================== template constexpr void apply_causal_mask(matrix_ref St) { @@ -389,10 +404,268 @@ constexpr void apply_causal_mask(matrix_ref St) { } } -#ifdef CM_HAS_LSC_UNTYPED_2D +//prepack [K, N] to [K/2, N, 2] layout. +template +inline void prepackAsVNNIWidth2(matrix_ref input, matrix_ref out) { + #pragma unroll + for (int r = 0; r < K/2; r++) { + out.row(r).select(0) = input.row(r*2); + out.row(r).select(1) = input.row(r*2+1); + } +} -template -void sdpa_kernel_lsc( +//@prefetch_u8 would have duplicated decompress perf issue. comments out for now. +// template +// void sdpa_kernel_lsc_prefetch_u8( +// int wg_local_id, +// int q_start, +// int kv_stop, // +// int q_len, //q_step +// int kv_len, //not used for now +// svmptr_t q_base [[type("svmptr_t")]], +// svmptr_t k_cache_base [[type("svmptr_t")]], +// svmptr_t v_cache_base [[type("svmptr_t")]], +// svmptr_t o_base [[type("svmptr_t")]], +// int32_t past_lens, +// int32_t* block_indices [[type("svmptr_t")]]) { +// constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); +// constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; +// //[block_num, kv_heads, block_size, head_size] +// constexpr uint kv_pitch = head_size * sizeof(uint8_t); + +// vector cur_max; +// vector cur_sum; + +// cur_max = -3e38f; +// cur_sum = 0; +// constexpr int num_P_tiles = REG_N / REG_M; +// matrix rQ; +// matrix rO; + +// auto q_tokens_left = q_len;// - q_start; +// static_assert(q_step == REG_N); +// static_assert(kv_step == REG_K); + +// if (q_tokens_left < 0) q_tokens_left = 0; +// if (q_tokens_left > q_step) q_tokens_left = q_step; + +// if (q_tokens_left > 0) { +// lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0); +// #pragma unroll +// for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { +// cm_load(rQ[ri].format(), b2dQ.set_block_x(k)); +// rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); +// } +// } + +// lsc::block_2d_desc b2dKV(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + +// static_assert(wg_local_size == 16); +// lsc::block_2d_desc b2dKV_prefetch(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); +// // constexpr int blk_stride = CMFLA_NUM_KV_HEADS * CMFLA_HEAD_SIZE * CMPA_BLOCK_SZ; +// constexpr int quan_blk_stride = CMFLA_NUM_KV_HEADS * (CMFLA_HEAD_SIZE+4) * CMPA_BLOCK_SZ * sizeof(uint8_t); + + + +// int causal_left = q_start+past_lens; +// for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step) { +// auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; +// //For the last step, duplicate prefetch here. +// uint32_t prefetch_kv_pos = (kv_pos+kv_step) >= kv_stop ? kv_pos : (kv_pos+kv_step); +// auto prefetch_block_id = block_indices[prefetch_kv_pos / CMPA_BLOCK_SZ]; +// uint32_t dscale_offset = cur_block_id*quan_blk_stride + CMPA_BLOCK_SZ * head_size * sizeof(uint8_t) + kv_pos%CMPA_BLOCK_SZ*sizeof(half); + +// vector dscale; +// vector zp; +// cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset), dscale); +// cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset + CMPA_BLOCK_SZ*sizeof(half)), zp); + +// // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { +// // show(dscale.format()); +// // } +// //# St = k @ Qt + +// matrix St; // = ugemm_KQ(slm_K, rQ, slm_offset); +// { +// constexpr int num_K = kv_step/REG_M; +// auto St2 = St.format(); +// matrix Kmat; +// auto quan_Kmat = Kmat.format().row(1).format(); +// auto dq_Kmat = Kmat.format(); +// //cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); + +// b2dKV_prefetch.set_base_ptr(reinterpret_cast(k_cache_base+prefetch_block_id*quan_blk_stride)); +// b2dKV_prefetch.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); +// cm_prefetch(b2dKV_prefetch.set_block_x(0)); + +// b2dKV.set_base_ptr(reinterpret_cast(k_cache_base+cur_block_id*quan_blk_stride)); +// b2dKV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + +// cm_load(quan_Kmat.format(), b2dKV.set_block_x(0)); +// // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { +// // show(quan_Kmat.format(), false); +// // } +// #pragma unroll +// for(int r = 0; r < kv_step; r++) { +// dq_Kmat[r] = quan_Kmat[r] - zp[r]; +// dq_Kmat[r] = cm_mul(dq_Kmat[r], dscale[r]); +// } + +// #pragma unroll +// for(int k = 0; k < num_K; k++) +// St2.row(k) = cm_dpas( +// 0, +// rQ[0].format(), +// Kmat[k].format()); + +// #pragma unroll +// for(int ri = 1; ri < head_size/REG_K; ri++) { +// cm_prefetch(b2dKV_prefetch.set_block_x(ri*REG_K)); +// //cm_load(Kmat.format(), b2dKV.set_block_x(ri*REG_K)); +// cm_load(quan_Kmat.format(), b2dKV.set_block_x(ri*REG_K)); +// #pragma unroll +// for(int r = 0; r < kv_step; r++) { +// dq_Kmat[r] = quan_Kmat[r] - zp[r]; +// dq_Kmat[r] = cm_mul(dq_Kmat[r], dscale[r]); +// } +// #pragma unroll +// for(int k = 0; k < num_K; k++) { +// St2.row(k) = cm_dpas( +// St2.row(k), +// rQ[ri].format(), +// Kmat[k].format()); +// } +// } +// } +// if constexpr (use_causal_mask) { +// // since kv_step == q_step == 16, causal_left is n*kv_step +// if (causal_left == 0) { +// apply_causal_mask<1>(St); +// } else if (causal_left < 0) { +// St = -3.4e38f; +// } +// causal_left -= kv_step; +// } else { +// int kv_tokens = kv_stop - kv_pos; +// // LSC ensures no overflow-access, but mask off k-tails attn-score is still required +// for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; +// } + +// //show(St); +// auto max_comp = online_softmax_update(St, cur_max, cur_sum); + +// matrix P; +// Transpose2DMatrix(St, P); + +// b2dKV_prefetch.set_base_ptr(reinterpret_cast(v_cache_base+prefetch_block_id*quan_blk_stride)); +// b2dKV_prefetch.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); + +// b2dKV.set_base_ptr(reinterpret_cast(v_cache_base+cur_block_id*quan_blk_stride)); +// b2dKV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + + +// cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset), dscale); +// cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset+CMPA_BLOCK_SZ*sizeof(half)), zp); + +// { +// matrix VmatVNNI2; +// matrix Vmat; +// auto quanVmat = Vmat.format().row(1).format(); +// int kv_tokens = kv_stop - kv_pos; +// if (kv_pos == 0) { +// // ugemm_PV0(slm_V, P, rO, slm_offset); +// auto P2 = P.format(); +// #pragma unroll +// for(int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) { +// cm_prefetch(b2dKV_prefetch.set_block_x(k)); +// cm_load(quanVmat.format(), b2dKV.set_block_x(k)); +// #pragma unroll +// for(int r = 0; r < kv_step;r++) { +// Vmat[r] = quanVmat[r]-zp[r]; +// Vmat[r] = cm_mul(Vmat[r], dscale[r]); +// } +// for(int r = kv_step-1; r>=kv_tokens;r--) { +// Vmat[r] = 0; +// } + +// prepackAsVNNIWidth2(Vmat, VmatVNNI2); + +// #pragma unroll +// for(int p = 0; p < num_P_tiles; p++) { +// rO[ri + p] = cm_dpas( +// 0, +// VmatVNNI2.format(), +// P2.row(p).format()); +// } +// } +// } +// else { +// //ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); +// auto P2 = P.format(); +// #pragma unroll +// for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { +// cm_prefetch(b2dKV_prefetch.set_block_x(k)); +// cm_load(quanVmat.format(), b2dKV.set_block_x(k)); + +// #pragma unroll +// for(int r = 0; r < kv_step;r++) { +// Vmat[r] = quanVmat[r]-zp[r]; +// Vmat[r] = cm_mul(Vmat[r], dscale[r]); +// } +// for(int r = kv_step-1; r>=kv_tokens;r--) { +// Vmat[r] = 0; +// } + +// prepackAsVNNIWidth2(Vmat, VmatVNNI2); +// //# compensate cur_O +// // matrix rO; +// #pragma unroll +// for(int p = 0; p < num_P_tiles; p++) { +// auto cO = rO[ri + p].format(); +// #pragma unroll +// for(int r = 0; r < REG_M; r++) +// cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); +// } + +// #pragma unroll +// for(int p = 0; p < num_P_tiles; p++) { +// rO[ri + p] = cm_dpas( +// rO[ri + p].format(), +// VmatVNNI2.format(), +// P2.row(p).format()); +// } +// } +// } +// } +// } +// if (q_tokens_left == 0) return; + +// //# save cur_O/cur_sum.transpose(0, 1) +// matrix cur_O_f16; +// cur_sum = cm_inv(cur_sum); + +// lsc::block_2d_desc b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0); + +// #pragma unroll +// for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { +// #pragma unroll +// for(int p = 0; p < num_P_tiles; p++) { +// auto cO = rO[ri + p].format(); +// #pragma unroll +// for(int r = 0; r < cO.n_rows(); r++) { +// cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); + +// } +// } +// b2dO.set_block_x(k); +// cm_store(b2dO.set_block_y(0), cur_O_f16.format().row(0)); +// cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); +// } +// } + +#if CMPA_KVCACHE_U8 +template +void pa_lsc_u8( uint slm_K, uint slm_V, int wg_local_id, @@ -402,13 +675,20 @@ void sdpa_kernel_lsc( int q_len, int kv_len, svmptr_t q_base [[type("svmptr_t")]], - svmptr_t k_base [[type("svmptr_t")]], - svmptr_t v_base [[type("svmptr_t")]], - svmptr_t o_base [[type("svmptr_t")]]) { + svmptr_t k_cache_base [[type("svmptr_t")]], + svmptr_t v_cache_base [[type("svmptr_t")]], +#if SPARSE_BLOCK_SIZE > 1 + svmptr_t sparse_mask_base [[type("svmptr_t")]], + svmptr_t wg_sparse_mask_base [[type("svmptr_t")]], +#endif + svmptr_t o_base [[type("svmptr_t")]], + int32_t past_lens, + int32_t* block_indices [[type("svmptr_t")]]) { constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); - constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; - constexpr uint kv_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); + constexpr uint q_pitch = is_q_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; + //[block_num, kv_heads, block_size, head_size] + constexpr uint kv_pitch = head_size * sizeof(uint8_t); vector cur_max; vector cur_sum; @@ -435,71 +715,152 @@ void sdpa_kernel_lsc( } } - lsc::block_2d_desc b2dK(k_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); - lsc::block_2d_desc b2dV(v_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); - - int causal_left = q_start; + lsc::block_2d_desc b2dK(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + lsc::block_2d_desc b2dV(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + constexpr int quan_blk_stride = CMFLA_NUM_KV_HEADS * (CMFLA_HEAD_SIZE+4) * CMPA_BLOCK_SZ * sizeof(uint8_t); + int causal_left = q_start+past_lens; constexpr uint slm_buff_size = kv_step * head_size * sizeof(half); int slm_buff_id_write = 0; int slm_buff_id_read = 0; +#if SPARSE_BLOCK_SIZE > 1 + auto skip_compute = [&](int kv_pos) { + auto kv_start_block = kv_pos / SPARSE_BLOCK_SIZE; + bool sparse_mask = *(reinterpret_cast(sparse_mask_base) + kv_start_block); + + return !sparse_mask; + }; + auto skip_load = [&](int kv_pos) { + auto kv_start_block = kv_pos / SPARSE_BLOCK_SIZE; + bool sparse_mask = *(reinterpret_cast(wg_sparse_mask_base) + kv_start_block); + return !sparse_mask; + }; +#endif + auto load_slm_KV = [&](int kv_pos) { if (kv_pos < kv_stop) { +#if SPARSE_BLOCK_SIZE > 1 + if (skip_load(kv_pos)) { + slm_buff_id_write++; + return; + } +#endif + auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; + uint32_t dscale_offset = cur_block_id*quan_blk_stride + \ + CMPA_BLOCK_SZ * head_size * sizeof(uint8_t) + kv_pos%CMPA_BLOCK_SZ*sizeof(half); + uint slm_offset = (slm_buff_id_write & 3) * slm_buff_size; + vector dscale; + vector zp; + int kv_left = (kv_stop-kv_pos) > kv_step ? kv_step: (kv_stop-kv_pos); + slm_buff_id_write ++; if (wg_local_id < local_size/2) { - vector temp0; - b2dK.set_block_y(kv_pos); + cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset), dscale); + cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset + CMPA_BLOCK_SZ*sizeof(half)), zp); + + matrix kmat; + auto quanKmat = kmat.format()[1].format(); + b2dK.set_base_ptr(reinterpret_cast(k_cache_base+cur_block_id*quan_blk_stride)); + b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ); + for(int k = REG_K*wg_local_id; k < head_size; k += REG_K*(local_size/2)) { - cm_load(temp0, b2dK.set_block_x(k)); - cm_slm_block_write(slm_K, slm_offset + k * kv_step * sizeof(half), temp0); + cm_load(quanKmat.format(), b2dK.set_block_x(k)); + /*@bug: cm compiler in the tail process. + : loop combined with type convert. + for(int r = 0; r < kv_left; r++) { + kmat[r] = quanKmat[r]-zp[r]; + kmat[r] = cm_mul(kmat[r], dscale[r]); + } + wa: unroll all kv_step rows. set 0 to padding rows. + */ + #pragma unroll + for(int r = 0; r < kv_step; r++) { + kmat[r] = quanKmat[r]-zp[r]; + kmat[r] = cm_mul(kmat[r], dscale[r]); + } + //clear unused data to 0. + for(int r = kv_step-1; r >= kv_left; r--) + kmat[r] = 0; + cm_slm_block_write(slm_K, slm_offset + k * kv_step * sizeof(half), kmat.format()); } } else { - vector temp2; - b2dV.set_block_y(kv_pos); + cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset), dscale); + cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset+CMPA_BLOCK_SZ*sizeof(half)), zp); + + matrix VmatVNNI; + matrix Vmat; + auto quanVmat = Vmat.format().row(1).format(); + b2dV.set_base_ptr(reinterpret_cast(v_cache_base+cur_block_id*quan_blk_stride)); + b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + #pragma unroll for(int k = REG_N*(wg_local_id-(local_size/2)); k < head_size; k += REG_N*(local_size/2)) { - cm_load(temp2, b2dV.set_block_x(k)); - cm_slm_block_write(slm_V, slm_offset + k * REG_K * sizeof(half), temp2); + cm_load(quanVmat.format(), b2dV.set_block_x(k)); + /*@bug: cm compiler in the tail process. + : loop combined with type convert. + for(int r = 0; r < kv_left; r++) { + Vmat[r] = quanVmat[r]-zp[r]; + Vmat[r] = cm_mul(Vmat[r], dscale[r]); + } + */ + #pragma unroll + for(int r = 0; r < kv_step;r++) { + Vmat[r] = quanVmat[r]-zp[r]; + Vmat[r] = cm_mul(Vmat[r], dscale[r]); + } + + for(int r = kv_step-1; r>=kv_left;r--) { + Vmat[r] = 0; + } + prepackAsVNNIWidth2(Vmat, VmatVNNI); + cm_slm_block_write(slm_V, slm_offset + k * REG_K * sizeof(half), VmatVNNI.format()); } } } }; + load_slm_KV(0); load_slm_KV(kv_step); cm_slm_fence(CM_LOCAL_BARRIER); cm_sbarrier(1); - for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step, - k_base += kv_step * kv_pitch, - v_base += kv_step * kv_pitch, - slm_buff_id_read ++) { + for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step,slm_buff_id_read++) { - // load0, load1, signal1, - // [wait2, signal2, load2, read0] - // [wait3, signal3, load3, read1] - // [wait4, signal4, load4, read2] - // [wait5, signal5, load5, read3] + // load0, load1, signal1, + // [wait1, signal2, load2, read0, compute0] + // [wait2, signal3, load3, read1, compute1] + // [wait3, signal4, load4, read2, compute2] + // [wait4, signal5, load5, read3, compute3] // - // after wait4, all workers have reached signal3, so: - // - all workers have finished load2 & read0. - // - we can start to load 4 into SLM slot 0 (i & 3) safely + // after wait3, all workers have reached signal3, so: + // - all workers have finished load2 & read0. + // - we can start to load 4 into SLM slot 0 (i & 3) safely // - we can start to read 2 ((i-2) & 3) safely + cm_fence(CM_LOCAL_BARRIER); cm_sbarrier(0); - //if (kv_pos > 1024000) // for debugging + //if (kv_pos > 1024000) if (kv_pos + kv_step < kv_stop) cm_sbarrier(1); - load_slm_KV(kv_pos + kv_step*2); + +#if SPARSE_BLOCK_SIZE > 1 + if (skip_compute(kv_pos)) { + if constexpr (use_causal_mask) + causal_left -= kv_step; + continue; + } +#endif { + uint slm_offset = (slm_buff_id_read & 3) * slm_buff_size; + //# St = k @ Qt matrix St = ugemm_KQ(slm_K, rQ, slm_offset); - if constexpr (use_causal_mask) { // since kv_step == q_step == 16, causal_left is n*kv_step if (causal_left == 0) { @@ -513,8 +874,6 @@ void sdpa_kernel_lsc( // LSC ensures no overflow-access, but mask off k-tails attn-score is still required for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; } - - //show(St); auto max_comp = online_softmax_update(St, cur_max, cur_sum); matrix P; @@ -551,198 +910,21 @@ void sdpa_kernel_lsc( } } -template -void sdpa_kernel_lsc_prefetch( - int wg_local_id, - int q_start, - int kv_stop, - int q_len, - int kv_len, - svmptr_t q_base [[type("svmptr_t")]], - svmptr_t k_base [[type("svmptr_t")]], - svmptr_t v_base [[type("svmptr_t")]], - svmptr_t o_base [[type("svmptr_t")]]) { - - constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); - constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; - constexpr uint kv_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); - - vector cur_max; - vector cur_sum; - - cur_max = -3e38f; - cur_sum = 0; - constexpr int num_P_tiles = REG_N / REG_M; - matrix rQ; - matrix rO; - - auto q_tokens_left = q_len;// - q_start; - static_assert(q_step == REG_N); - static_assert(kv_step == REG_K); - - if (q_tokens_left < 0) q_tokens_left = 0; - if (q_tokens_left > q_step) q_tokens_left = q_step; - - if (q_tokens_left > 0) { - lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0); - #pragma unroll - for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { - cm_load(rQ[ri].format(), b2dQ.set_block_x(k)); - rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); - } - } - - lsc::block_2d_desc b2dK(k_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); - lsc::block_2d_desc b2dV(v_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); - - static_assert(wg_local_size == 16); - lsc::block_2d_desc prefetch_K(k_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); - lsc::block_2d_desc prefetch_V(v_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); - - int causal_left = q_start; - - for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step, - k_base += kv_step * kv_pitch, - v_base += kv_step * kv_pitch) { - //# St = k @ Qt - matrix St; // = ugemm_KQ(slm_K, rQ, slm_offset); - { - constexpr int num_K = kv_step/REG_M; - auto St2 = St.format(); - - matrix Kmat; - //cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); - prefetch_K.set_block_y(wg_local_id + kv_pos + kv_step); - cm_prefetch(prefetch_K.set_block_x(0)); - - b2dK.set_block_y(kv_pos); - cm_load(Kmat.format(), b2dK.set_block_x(0)); - #pragma unroll - for(int k = 0; k < num_K; k++) - St2.row(k) = cm_dpas( - 0, - rQ[0].format(), - Kmat[k].format()); - - #pragma unroll - for(int ri = 1; ri < head_size/REG_K; ri++) { - //cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); - cm_prefetch(prefetch_K.set_block_x(ri*REG_K)); - cm_load(Kmat.format(), b2dK.set_block_x(ri*REG_K)); - #pragma unroll - for(int k = 0; k < num_K; k++) { - St2.row(k) = cm_dpas( - St2.row(k), - rQ[ri].format(), - Kmat[k].format()); - } - } - } - if constexpr (use_causal_mask) { - // since kv_step == q_step == 16, causal_left is n*kv_step - if (causal_left == 0) { - apply_causal_mask<1>(St); - } else if (causal_left < 0) { - St = -3.4e38f; - } - causal_left -= kv_step; - } else { - int kv_tokens = kv_stop - kv_pos; - // LSC ensures no overflow-access, but mask off k-tails attn-score is still required - for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; - } - - //show(St); - auto max_comp = online_softmax_update(St, cur_max, cur_sum); - - matrix P; - Transpose2DMatrix(St, P); - - b2dV.set_block_y(kv_pos); - prefetch_V.set_block_y(wg_local_id +kv_pos + kv_step); - if (kv_pos == 0) { - // ugemm_PV0(slm_V, P, rO, slm_offset); - auto P2 = P.format(); - #pragma unroll - for(int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) { - matrix Vmat; - cm_prefetch(prefetch_V.set_block_x(k)); - cm_load(Vmat.format(), b2dV.set_block_x(k)); - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - rO[ri + p] = cm_dpas( - 0, - Vmat.format(), - P2.row(p).format()); - } - } - } - else { - //ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); - auto P2 = P.format(); - #pragma unroll - for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { - matrix Vmat; - - cm_prefetch(prefetch_V.set_block_x(k)); - cm_load(Vmat.format(), b2dV.set_block_x(k)); - - //# compensate cur_O - // matrix rO; - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - auto cO = rO[ri + p].format(); - #pragma unroll - for(int r = 0; r < REG_M; r++) - cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); - } - - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - rO[ri + p] = cm_dpas( - rO[ri + p].format(), - Vmat.format(), - P2.row(p).format()); - } - } - } - } - if (q_tokens_left == 0) return; - - //# save cur_O/cur_sum.transpose(0, 1) - matrix cur_O_f16; - cur_sum = cm_inv(cur_sum); - - lsc::block_2d_desc b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0); - - #pragma unroll - for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - auto cO = rO[ri + p].format(); - #pragma unroll - for(int r = 0; r < cO.n_rows(); r++) { - cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); - } - } - b2dO.set_block_x(k); - cm_store(b2dO.set_block_y(0), cur_O_f16.format().row(0)); - cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); - } -} +#else template -void pa_kernel_lsc_prefetch( +void pa_kernel_lsc_prefetch_f16( int wg_local_id, int q_start, int kv_stop, // int q_len, //q_step int kv_len, //not used for now svmptr_t q_base [[type("svmptr_t")]], - svmptr_t k_base [[type("svmptr_t")]], - svmptr_t v_base [[type("svmptr_t")]], + svmptr_t k_cache_base [[type("svmptr_t")]], + svmptr_t v_cache_base [[type("svmptr_t")]], #if SPARSE_BLOCK_SIZE > 1 svmptr_t sparse_mask_base [[type("svmptr_t")]], + svmptr_t wg_sparse_mask_base [[type("svmptr_t")]], #endif svmptr_t o_base [[type("svmptr_t")]], int32_t past_lens, @@ -758,25 +940,19 @@ void pa_kernel_lsc_prefetch( vector cur_max; vector cur_sum; - bool need_comp = false; - cur_max = -3e38f; cur_sum = 0; constexpr int num_P_tiles = REG_N / REG_M; matrix rQ; matrix rO; - auto q_tokens_left = q_len; + auto q_tokens_left = q_len;// - q_start; static_assert(q_step == REG_N); static_assert(kv_step == REG_K); if (q_tokens_left < 0) q_tokens_left = 0; if (q_tokens_left > q_step) q_tokens_left = q_step; -#if SPARSE_BLOCK_SIZE > 1 - // printf("wg:%d.%d q: %d, +%d kv: %d, x-attn: %p\n", 0, wg_local_id, q_start, q_tokens_left, kv_stop, reinterpret_cast(sparse_mask_base)); -#endif - if (q_tokens_left > 0) { lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0); #pragma unroll @@ -786,12 +962,12 @@ void pa_kernel_lsc_prefetch( } } - lsc::block_2d_desc b2dK(k_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); - lsc::block_2d_desc b2dV(v_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); + lsc::block_2d_desc b2dK(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); + lsc::block_2d_desc b2dV(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); static_assert(wg_local_size == 16); - lsc::block_2d_desc prefetch_K(k_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); - lsc::block_2d_desc prefetch_V(v_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); + lsc::block_2d_desc prefetch_K(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); + lsc::block_2d_desc prefetch_V(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); constexpr int blk_stride = CMFLA_NUM_KV_HEADS*CMFLA_HEAD_SIZE*CMPA_BLOCK_SZ; int causal_left = q_start+past_lens; @@ -801,15 +977,14 @@ void pa_kernel_lsc_prefetch( uint32_t prefetch_kv_pos = (kv_pos+kv_step) >= kv_stop ? kv_pos : (kv_pos+kv_step); auto prefetch_block_id = block_indices[prefetch_kv_pos / CMPA_BLOCK_SZ]; //# St = k @ Qt - matrix St; // = ugemm_KQ(slm_K, rQ, slm_offset); + matrix St; { constexpr int num_K = kv_step/REG_M; auto St2 = St.format(); matrix Kmat; - //cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); - prefetch_K.set_base_ptr((reinterpret_cast(k_base)+prefetch_block_id*blk_stride)); + prefetch_K.set_base_ptr((reinterpret_cast(k_cache_base)+prefetch_block_id*blk_stride)); prefetch_K.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); cm_prefetch(prefetch_K.set_block_x(0)); @@ -825,8 +1000,7 @@ void pa_kernel_lsc_prefetch( } } #endif - - b2dK.set_base_ptr((reinterpret_cast(k_base)+cur_block_id*blk_stride)); + b2dK.set_base_ptr((reinterpret_cast(k_cache_base)+cur_block_id*blk_stride)); b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ); cm_load(Kmat.format(), b2dK.set_block_x(0)); #pragma unroll @@ -838,7 +1012,6 @@ void pa_kernel_lsc_prefetch( #pragma unroll for(int ri = 1; ri < head_size/REG_K; ri++) { - //cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); cm_prefetch(prefetch_K.set_block_x(ri*REG_K)); cm_load(Kmat.format(), b2dK.set_block_x(ri*REG_K)); #pragma unroll @@ -864,18 +1037,18 @@ void pa_kernel_lsc_prefetch( for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; } - // show(St); + //show(St); auto max_comp = online_softmax_update(St, cur_max, cur_sum); matrix P; Transpose2DMatrix(St, P); - prefetch_V.set_base_ptr((reinterpret_cast(v_base)+prefetch_block_id*blk_stride)); + prefetch_V.set_base_ptr((reinterpret_cast(v_cache_base)+prefetch_block_id*blk_stride)); prefetch_V.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); - b2dV.set_base_ptr((reinterpret_cast(v_base)+cur_block_id*blk_stride)); + b2dV.set_base_ptr((reinterpret_cast(v_cache_base)+cur_block_id*blk_stride)); b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ); - if (need_comp == false) { + if (kv_pos == 0) { // ugemm_PV0(slm_V, P, rO, slm_offset); auto P2 = P.format(); #pragma unroll @@ -889,11 +1062,8 @@ void pa_kernel_lsc_prefetch( 0, Vmat.format(), P2.row(p).format()); - // show(rO[ri + p].format()); } } - - need_comp = true; } else { //ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); @@ -921,7 +1091,6 @@ void pa_kernel_lsc_prefetch( rO[ri + p].format(), Vmat.format(), P2.row(p).format()); - // show(rO[ri + p].format()); } } } @@ -950,206 +1119,5 @@ void pa_kernel_lsc_prefetch( cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); } } -#endif - -template -void sdpa_kernel( - uint slm_K, - uint slm_V, - int wg_local_id, - int local_size, - int q_start, - int kv_stop, - int q_len, - int kv_len, - SurfaceIndex query [[type("buffer_t")]], - SurfaceIndex key [[type("buffer_t")]], - SurfaceIndex value [[type("buffer_t")]], - SurfaceIndex output [[type("buffer_t")]], - uint q_off, - uint k_off, - uint v_off, - uint o_off) { - - constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); - constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; - constexpr uint kv_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); - - vector cur_max; - vector cur_sum; - - cur_max = -3e38f; - cur_sum = 0; - - matrix rQ; - auto q_tokens_left = q_len; - static_assert(q_step == REG_N); - static_assert(kv_step == REG_K); - - if (q_tokens_left < 0) q_tokens_left = 0; - if (q_tokens_left > q_step) q_tokens_left = q_step; - - if (q_tokens_left > 0) { - // load as many as possible given one address - if constexpr (head_size == 128 || head_size == 64) { - matrix QmatI32; - cm_load_2d(QmatI32, query, q_off, q_pitch); - #pragma unroll - for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { - Transpose2DMatrix(QmatI32.select(0, k), rQ[ri].format()); - rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); - } - } else { - #pragma unroll - for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { - matrix QmatI32; - cm_load_2d(QmatI32, query, q_off + k * sizeof(uint), q_pitch); - Transpose2DMatrix(QmatI32, rQ[ri].format()); - rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); - } - } - } - - constexpr int num_P_tiles = REG_N / REG_M; - matrix rO; - int causal_left = q_start; - - constexpr uint slm_buff_size = kv_step * head_size * sizeof(half); - int slm_buff_id_write = 0; - int slm_buff_id_read = 0; - - auto load_slm_KV = [&](int kv_pos) { - //if (kv_pos < 1024000) return; - int kv_tokens = kv_stop - kv_pos; - if (kv_tokens <= 0) return; - uint slm_offset = (slm_buff_id_write & 3) * slm_buff_size; - slm_buff_id_write ++; - - // non-tail branch is faster - if (wg_local_id < local_size/2) { - //if (kv_pos > 1024000) { - matrix temp; - for(int k = REG_K * wg_local_id; k < head_size; k += REG_K*(local_size/2)) { - cm_load_2d(temp, key, k_off + k*sizeof(half), kv_pitch); - cm_slm_block_write(slm_K, - slm_offset + k * 2 * REG_M * sizeof(half), - temp.format()); - } - } else { - //if (kv_pos > 1024000) { - // read 16x16 XMX-B matrix (1x REG_N in Xe2, 2x REG_N in Xe1) - constexpr int VK_STEP = 16; - static_assert((VK_STEP % REG_N) == 0); - matrix temp2; - matrix temp_vnni; - //b2dV.set_block_y(kv_pos); - - static_assert((head_size % VK_STEP) == 0); - #pragma unroll - for(int k = VK_STEP * (wg_local_id-local_size/2); k < head_size; k += VK_STEP * (local_size/2)) { - cm_load_2d(temp2, value, v_off + k*sizeof(half), kv_pitch); - - #pragma unroll - for(int p = 0; p < VK_STEP/REG_N; p++) { - temp_vnni.select(0, 0) = temp2.select(0, p*REG_N); - temp_vnni.select(0, 1) = temp2.select(1, p*REG_N); - // show(temp_vnni); - cm_slm_block_write(slm_V, slm_offset + (k + p*REG_N) * REG_K * sizeof(half), temp_vnni.format()); - } - } - } - k_off += kv_step * kv_pitch; - v_off += kv_step * kv_pitch; - // printf(" diff= %lu\n", get_clock() - clk0); - }; - - load_slm_KV(0); - load_slm_KV(kv_step); - - cm_slm_fence(CM_LOCAL_BARRIER); - cm_sbarrier(1); - for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step, - slm_buff_id_read ++) { - // - // load0->0, signal1, - // [load1->1, wait2, signal2, read0] - // [load2->2, wait3, signal3, read1] - // [load3->3, wait4, signal4, read2] - // [load4->0, wait5, signal5, read3] - // - // after wait4, all workers have reached signal3, so: - // - all workers have finished load2 & read0. - // - we can start to load 4 into SLM slot 0 (i & 3) safely - // - we can start to read 2 ((i-2) & 3) safely - // - cm_fence(CM_LOCAL_BARRIER); - cm_sbarrier(0); - - load_slm_KV(kv_pos + 2*kv_step); - - if (kv_pos + kv_step < kv_stop) - cm_sbarrier(1); - - //if (kv_pos < 1024000) continue; - uint slm_offset = (slm_buff_id_read & 3) * slm_buff_size; - - //=========================================================== 1807 ~ 3247 - //# St = k @ Qt - matrix St = ugemm_KQ(slm_K, rQ, slm_offset); - - if constexpr (use_causal_mask) { - if (causal_left < kv_step) { - vector cmask = 0.0f; - int p = causal_left + 1; - int v = 0; - for(; p < 0; p++) { - cmask[v] = -3.4e38f; - if (v < q_step - 1) v++; - } - for(; p < kv_step; p++) { - cmask[v] = -3.4e38f; - St[p] = cm_add(St[p], cmask); - if (v < q_step - 1) v++; - } - //if (wg_local_id == 0) show(St);return; - } - causal_left -= kv_step; - } - - // mask off k-tails - int kv_tokens = kv_stop - kv_pos; - for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; - - //show(St); - auto max_comp = online_softmax_update(St, cur_max, cur_sum); - - matrix P; - Transpose2DMatrix(St, P); - - if (kv_pos == 0) - ugemm_PV0(slm_V, P, rO, slm_offset); - else - ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); - } - - if (q_tokens_left > 0) { - //# save cur_O/cur_sum.transpose(0, 1) - matrix cur_O_f16; - cur_sum = cm_inv(cur_sum); - - #pragma unroll - for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - auto cO = rO[ri + p].format(); - #pragma unroll - for(int r = 0; r < cO.n_rows(); r++) { - cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); - } - } - // if (i == args_verbose) show(cur_O_f16); - cm_store_2d(cur_O_f16, output, o_off + k*sizeof(half), o_pitch); - } - } -} +#endif \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm index c9198a8c89dfac..d24e55d52cabc3 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm @@ -32,8 +32,13 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const int32_t* block_indices [[type("svmptr_t")]], const int32_t* block_indices_begins [[type("svmptr_t")]], const int32_t* subsequence_begins [[type("svmptr_t")]], +#if KV_CACHE_COMPRESSION_PER_TOKEN + uint8_t* key_cache [[type("svmptr_t")]], + uint8_t* value_cache [[type("svmptr_t")]], +#else half* key_cache [[type("svmptr_t")]], - half* value_cache [[type("svmptr_t")]], + half* value_cache [[type("svmptr_t")]], +#endif uint32_t key_pitch, uint32_t key_offset, uint32_t value_pitch, @@ -84,6 +89,30 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx; + #if KV_CACHE_COMPRESSION_PER_TOKEN + // Assume: K_HEAD_SIZE == K_HEAD_SIZE + auto quantize_and_store = [&](vector data, uchar* out, uint out_offset, uint token_pos) { + uint scale_offset = out_offset + K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + token_pos * sizeof(half); + half max_val = cm_reduced_max(data); + half min_val = cm_reduced_min(data); + half scale_val = half(0.0); + half zp_val = half(0.0); + if(max_val == min_val) { + scale_val = half(0.0); + zp_val = max_val; + } else { + scale_val = 255.0 / (max_val - min_val); + zp_val = (0.0 - min_val) * scale_val; + } + vector dequant_data = cm_mul(data, scale_val) + zp_val; + vector data_u8 = cm_rnde(dequant_data); + cm_ptr_store((uint32_t*)(out + out_offset + token_pos * K_HEAD_SIZE), 0, data_u8.format()); + half *out_scale_zp = (half*)(out + scale_offset); + out_scale_zp[0] = (max_val - min_val) / 255.0; + out_scale_zp[PAGED_ATTENTION_BLOCK_SIZE] = zp_val; + }; + #endif + { uint block_k_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; uint key_out_offset = block_k_base_offset + token_start_pos * K_HEAD_SIZE; @@ -91,7 +120,12 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( vector key_data; key_data.format() = cm_ptr_load((int*)key, key_in_offset * (int)sizeof(half)); + + #if KV_CACHE_COMPRESSION_PER_TOKEN + quantize_and_store(key_data, (uchar*)key_cache, block_k_base_offset, token_start_pos); + #else cm_ptr_store((int*)key_cache, key_out_offset * (int)sizeof(half), key_data.format()); + #endif } { uint block_v_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_V_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; @@ -106,6 +140,10 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( vector value_data; value_data.format() = cm_ptr_load((int*)value, value_in_offset * (int)sizeof(half)); + #if KV_CACHE_COMPRESSION_PER_TOKEN + quantize_and_store(value_data, (uchar*)value_cache, block_v_base_offset, token_start_pos); + #else cm_ptr_store((int*)value_cache, value_out_offset * (int)sizeof(half), value_data.format()); + #endif } } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm index c69a19bcab5945..8c9993b8e8612b 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm @@ -26,15 +26,22 @@ namespace KERNEL_NAME { //extern "C" _GENX_MAIN_ void pa_multi_token( extern "C" _GENX_MAIN_ void KERNEL_NAME( + //query [q_len, num_heads, S] half* query [[type("svmptr_t")]], - half* key [[type("svmptr_t")]], - half* value [[type("svmptr_t")]], +#if CMPA_KVCACHE_U8 + int8_t* k_cache [[type("svmptr_t")]], + int8_t* v_cache [[type("svmptr_t")]], +#else + half* k_cache [[type("svmptr_t")]], + half* v_cache [[type("svmptr_t")]], +#endif int32_t* past_lens [[type("svmptr_t")]], int32_t* block_indices [[type("svmptr_t")]], int32_t* block_indices_begins [[type("svmptr_t")]], int32_t* subsequence_begins [[type("svmptr_t")]], #if SPARSE_BLOCK_SIZE > 1 bool* sparse_block_mask [[type("svmptr_t")]], + bool* sparse_block_mask_wg [[type("svmptr_t")]], #endif half* output [[type("svmptr_t")]], int q_len) { @@ -44,16 +51,26 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( constexpr int num_kv_heads = CMFLA_NUM_KV_HEADS; constexpr int pa_block_sz = CMPA_BLOCK_SZ; //# query [q_len, num_heads, S] - //# key [kv_len, num_heads, S] - //# value [kv_len, num_heads, S] - //# sparse_block_mask [num_heads, q_blocks, kv_blocks] + //# k_cache [kv_len, num_heads, S] + //# v_cache [kv_len, num_heads, S] +#if CMPA_KVCACHE_U8 + constexpr uint K_SLM_SIZE = (4*kv_step * head_size * sizeof(half)); + constexpr uint V_SLM_SIZE = (4*kv_step * head_size * sizeof(half)); + constexpr uint Q_SLM_SIZE = 0;//(q_step * head_size * sizeof(half)) * local_size; + + cm_slm_init(K_SLM_SIZE + V_SLM_SIZE + Q_SLM_SIZE); + auto slm_K = cm_slm_alloc(K_SLM_SIZE); + auto slm_V = cm_slm_alloc(V_SLM_SIZE); + +#endif auto batch = cm_group_id(0); auto h = cm_group_id(1); auto hkv = h / (num_heads/num_kv_heads); auto wg_id = cm_group_id(2); // each work-group handles a sequence auto wg_local_id = cm_local_id(2); int local_size = cm_local_size(2); + int q_start_sg, kv_start, kv_seq_len, q_len_sg; // multiple work-groups are required to split a sequence, @@ -91,57 +108,71 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( --------------------------------- each grid can be [q_len_per_trunk, q_len_per_trunk]. For each trunk, [q_len_per_trunk, past_q_lens] must be calculated. Such as: `20`,`21`. but for the 22, - causal mask optimization can be applied. different wgs would has different kv stop. + casual mask optimization can be applied. differnt wgs would has different kv stop. //todo:kv_stop is wg level, should we change to sg level? + sglevel would cause sgs in one wg diverge. so leave for now. also one wg has same kvstop makes eaiser for kv copying/loading into SLM/cache. */ kv_stop = (wg_id + 1) * wg_seq_len + past_q_lens; if (kv_stop > kv_seq_len) kv_stop = kv_seq_len; } - - // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop); - // qkv fused - // constexpr uint num_total_heads = num_heads + num_kv_heads * 2; - // uint q_offset = (q_start*num_total_heads + h)*head_size; - // uint k_offset = (kv_start*num_total_heads + num_heads + hkv)*head_size; - // uint v_offset = (kv_start*num_total_heads + num_heads + num_kv_heads + hkv)*head_size; + // printf("###########wg:%d.%d q: %d, +%d kv: %d, +%d, kvstop:%d\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop); //Q/O[B, L, H, S] uint q_offset = (q_start_sg*num_heads + h)*head_size; - uint o_offset = (q_start_sg*num_heads + h)*head_size; - - //K/V[block_num, kv_heads, block_sz, head_sz] - uint k_offset = hkv*head_size*pa_block_sz; - uint v_offset = hkv*head_size*pa_block_sz; #if SPARSE_BLOCK_SIZE > 1 //# sparse_block_mask [num_heads, q_blocks, kv_blocks] auto q_start_block = q_start_sg/ SPARSE_BLOCK_SIZE; int q_blocks = (q_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE; int kv_blocks = (kv_seq_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE; + //[self.num_heads, q_block_num, kv_block_num] bool* block_mask_base = sparse_block_mask + (h * q_blocks + q_start_block)*kv_blocks; + //[self.num_heads, wg_count_along_query, kv_block_num)] + bool* wg_block_mask_base = sparse_block_mask_wg + (h * cm_group_count(2) + wg_id)*kv_blocks; // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d, x-attn: %d, %dx%d, %p, %p\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop, q_start_block, q_blocks, kv_blocks, sparse_block_mask, block_mask_base); #endif -#if USE_LSC == 1 - pa_kernel_lsc_prefetch( - wg_local_id, - q_start_sg, //q_start for SG, - kv_stop, - q_len_sg, //q_step, - kv_seq_len, //kv_len, not used for now - reinterpret_cast(query + q_offset), - reinterpret_cast(key + k_offset), - reinterpret_cast(value + v_offset), +#if CMPA_KVCACHE_U8 + uint kv_offset = hkv*(head_size+4)*pa_block_sz; + pa_lsc_u8( + slm_K, + slm_V, + wg_local_id, + local_size, + q_start_sg, //q_start for SG, + kv_stop, + q_len_sg, //q_step, + kv_seq_len, //kv_len, + reinterpret_cast(query + q_offset), + reinterpret_cast(k_cache + kv_offset), + reinterpret_cast(v_cache + kv_offset), #if SPARSE_BLOCK_SIZE > 1 - reinterpret_cast(block_mask_base), + reinterpret_cast(block_mask_base), + reinterpret_cast(wg_block_mask_base), + #endif - reinterpret_cast(output + o_offset), - past_q_lens, - block_indices); + reinterpret_cast(output + q_offset), + past_q_lens, + block_indices); #else - static_assert(0); + uint kv_offset = hkv*head_size*pa_block_sz; + pa_kernel_lsc_prefetch_f16( + wg_local_id, + q_start_sg, //q_start for SG, + kv_stop, + q_len_sg, //q_step, + kv_seq_len, //kv_len, + reinterpret_cast(query + q_offset), + reinterpret_cast(k_cache + kv_offset), + reinterpret_cast(v_cache + kv_offset), +#if SPARSE_BLOCK_SIZE > 1 + reinterpret_cast(block_mask_base), + reinterpret_cast(wg_block_mask_base), +#endif + reinterpret_cast(output + q_offset), + past_q_lens, + block_indices); #endif } - -} // NAMESPACE \ No newline at end of file +} // namespace KERNEL_NAME \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index bb7ea162349d72..84bf0accb61383 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -45,9 +45,6 @@ #define KV_PARTITION_STEP_NUM (KV_PARTITION_SIZE / KV_STEP) -#define KV_SCALE_ZP_SIZE 0 // 4: scale/zp size - - #define DEBUG_ENABLE 0 #if DEBUG_ENABLE template @@ -103,7 +100,7 @@ void show(vector vec) { //prepack [K, N] to [K/2, N, 2] layout. template -inline void prepackAsVNNIWidth2(matrix_ref input, matrix_ref out) { +inline void prepack_to_VNNI_W2(matrix_ref input, matrix_ref out) { #pragma unroll for (int r = 0; r < K/2; r++) { out.row(r).select(0) = input.row(r*2); @@ -498,7 +495,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( VmatNormal[r] = 0; } } - prepackAsVNNIWidth2(VmatNormal, Vmat.format()); + prepack_to_VNNI_W2(VmatNormal, Vmat.format()); #else cm_load(Vmat[0].format(), b2dV.set_block_y(kv_pos)); #endif diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp index 7b28c001879f23..45c956ae54b173 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp @@ -27,7 +27,7 @@ struct PagedAttentionImplementationManager : public ImplementationManager { }; static constexpr std::array supported_kv_types = { ov::element::f16, - // ov::element::i8, + ov::element::i8, }; auto& engine = node.get_program().get_engine(); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 5f505dc3e329f1..560f5809e83070 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -66,10 +66,25 @@ inline size_t get_kv_len(const RuntimeParams& params, const PagedAttentionStage& return 0; // Fallback case, should not be reached } +inline size_t get_input_kv_len(const RuntimeParams& params) { + auto key_shape = params.input_layouts[PagedAttentionInputIdx::KEY].get_shape(); + const size_t kv_len = key_shape[key_shape.size() - 2]; + return kv_len; +} + inline size_t get_aligned_kv_len(const size_t kv_len) { return (kv_len + PA_KV_CACHE_BLOCK_SIZE - 1) / PA_KV_CACHE_BLOCK_SIZE * PA_KV_CACHE_BLOCK_SIZE; } +inline bool get_kv_compressed(const RuntimeParams& params) { + auto key_cache_layout = params.input_layouts[PagedAttentionInputIdx::KEY_CACHE]; + if (data_type_traits::is_i8_u8(key_cache_layout.data_type)) { + return true; + } else { + return false; + } +} + int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, int64_t target_seq_len_block_size = 16) { // Since at prefill stage Q, K, V inputs may contain multiple sequences with arbitrary // target sequence lengths each (shape is [sequences_num * target_seq_len, num_heads * head_size]), @@ -268,10 +283,18 @@ JitConstants PagedAttentionGeneratorKVCacheUpdate::get_jit_constants(const kerne jit.make("KV_HEADS_NUM", desc->kv_heads_num); jit.make("K_HEAD_SIZE", desc->k_head_size); jit.make("V_HEAD_SIZE", desc->v_head_size); - jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size); - jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size); jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + if (get_kv_compressed(params)) { + jit.make("KV_CACHE_COMPRESSION_PER_TOKEN", 1); + jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size + 4); + jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size + 4); + } else { + jit.make("KV_CACHE_COMPRESSION_PER_TOKEN", 0); + jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size); + jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size); + } + return jit; } @@ -302,7 +325,8 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() const auto desc = params.typed_desc(); // auto rtp = static_cast(rt_params); - const size_t kv_len = get_max_context_len(params); + // const size_t kv_len = get_max_context_len(params); + const size_t kv_len = get_input_kv_len(params); const size_t kv_heads_num = desc->kv_heads_num; const size_t wg_count = (kv_len + WG_SIZE - 1) / WG_SIZE; @@ -372,7 +396,8 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() if (DEBUG_ENABLED) { // Debug std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " << "kv_len: " << kv_len << ", key_pitch: " << key_pitch << ", key_offset: " << key_offset << ", value_pitch: " << value_pitch - << ", value_offset: " << value_offset << ", "<< std::endl; + << ", value_offset: " << value_offset << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" + << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } // TODO: support multiple sequences @@ -429,6 +454,12 @@ JitConstants PagedAttentionGeneratorMultiToken::get_jit_constants(const kernel_i jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE); jit.make("SPARSE_BLOCK_SIZE", xattn_block_size); jit.make("Q_STEP", get_q_step(xe_arch, true)); + + if (get_kv_compressed(params)) { + jit.make("CMPA_KVCACHE_U8", 1); + } else { + jit.make("CMPA_KVCACHE_U8", 0); + } // for (auto& it : jit) { // std::cout << "\tjit[" << it.name << "] = " << it.value << std::endl; // } @@ -509,7 +540,11 @@ JitConstants PagedAttentionGeneratorSingleToken::get_jit_constants(const kernel_ jit.make("KV_HEADS_NUM", desc->kv_heads_num); jit.make("Q_STEP", get_q_step(xe_arch, true)); - jit.make("KV_CACHE_COMPRESSION", 0); + if (get_kv_compressed(params)) { + jit.make("KV_CACHE_COMPRESSION", 1); + } else { + jit.make("KV_CACHE_COMPRESSION", 0); + } return jit; } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp index fb5cf4631bace7..d2ccc46bf01e37 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp @@ -26,7 +26,7 @@ struct PagedAttentionOpt : public ImplementationManager { }; static constexpr std::array supported_kv_types = { #if ENABLE_PA_CM_PATH - ov::element::i8, + ov::element::f32, #else ov::element::f32, ov::element::f16, From 55ba7c3ce8776fc9b5a300fbfb0c538b30564043 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 22 Sep 2025 13:42:56 +0800 Subject: [PATCH 22/96] refactor: split into pa_common and sdpa_common, which include attention_common. --- .../impls/cm/include/cm_attention_common.hpp | 415 +++++++++ .../graph/impls/cm/include/cm_pa_common.hpp | 475 ++++++++++ .../graph/impls/cm/include/cm_sdpa_common.hpp | 847 ++++++------------ .../src/graph/impls/cm/pa_multi_token.cm | 2 +- 4 files changed, 1159 insertions(+), 580 deletions(-) create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp new file mode 100644 index 00000000000000..325cd25fc5ba59 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp @@ -0,0 +1,415 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +#include +#include + +//# CM-compiler is C++17 +static_assert(__cplusplus >= 201703L); +//# static_assert(__cplusplus >= 202002L); +//# static_assert(__cplusplus >= 202302L); + +#define SystolicDepth 8 +#define RepeatCount 8 +#define VNNI_WIDTH 2 +#define REG_K (SystolicDepth * VNNI_WIDTH) +#define REG_M RepeatCount +//REG_N +// Xe1: 8 +// Xe2: 16 +#define REG_N (CM_GRF_WIDTH/32) + +#define kv_step REG_K +#define q_step REG_N + +constexpr float scale_factor = CMFLA_SCALE_FACTOR; + +static_assert(q_step == 16 || q_step == 8); +static_assert(kv_step == 16); +static_assert(CM_HAS_DPAS); + +#define DEBUG_SHOW 1 +#if !DEBUG_SHOW +template +void show(const matrix mat, bool isfloat=true) { +} +#else +template +void show(const matrix mat, bool isfloat=true) { + printf("Matrix [%d, %d]:\n", M, N); + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + if (isfloat) + printf("%8.4f,", mat[m][n]); + else + printf("%8d,", mat[m][n]); + + } + printf("],\n"); + } + printf("]\n"); +} +#endif +template +CM_INLINE void Transpose_16x16(matrix_ref in, + matrix_ref out) { + matrix bBuf; + bBuf.row(0) = in.template select<4, 1, 4, 4>(0, 0); // 0,4,8,c + bBuf.row(1) = in.template select<4, 1, 4, 4>(4, 0); // 0,4,8,c + bBuf.row(2) = in.template select<4, 1, 4, 4>(8, 0); // 0,4,8,c + bBuf.row(3) = in.template select<4, 1, 4, 4>(12, 0); // 0,4,8,c + bBuf.row(4) = in.template select<4, 1, 4, 4>(0, 1); // 1,5,9,d + bBuf.row(5) = in.template select<4, 1, 4, 4>(4, 1); // 1,5,9,d + bBuf.row(6) = in.template select<4, 1, 4, 4>(8, 1); // 1,5,9,d + bBuf.row(7) = in.template select<4, 1, 4, 4>(12, 1); // 1,5,9,d + bBuf.row(8) = in.template select<4, 1, 4, 4>(0, 2); // 2,6,a,e + bBuf.row(9) = in.template select<4, 1, 4, 4>(4, 2); // 2,6,a,e + bBuf.row(10) = in.template select<4, 1, 4, 4>(8, 2); // 2,6,a,e + bBuf.row(11) = in.template select<4, 1, 4, 4>(12, 2); // 2,6,a,e + bBuf.row(12) = in.template select<4, 1, 4, 4>(0, 3); // 3,7,b,f + bBuf.row(13) = in.template select<4, 1, 4, 4>(4, 3); // 3,7,b,f + bBuf.row(14) = in.template select<4, 1, 4, 4>(8, 3); // 3,7,b,f + bBuf.row(15) = in.template select<4, 1, 4, 4>(12, 3); // 3,7,b,f + + out.row(0) = bBuf.template select<4, 1, 4, 4>(0, 0); // 0 + out.row(1) = bBuf.template select<4, 1, 4, 4>(4, 0); // 1 + out.row(2) = bBuf.template select<4, 1, 4, 4>(8, 0); // 2 + out.row(3) = bBuf.template select<4, 1, 4, 4>(12, 0); // 3 + out.row(4) = bBuf.template select<4, 1, 4, 4>(0, 1); // 4 + out.row(5) = bBuf.template select<4, 1, 4, 4>(4, 1); // 5 + out.row(6) = bBuf.template select<4, 1, 4, 4>(8, 1); // 6 + out.row(7) = bBuf.template select<4, 1, 4, 4>(12, 1); // 7 + out.row(8) = bBuf.template select<4, 1, 4, 4>(0, 2); // 8 + out.row(9) = bBuf.template select<4, 1, 4, 4>(4, 2); // 9 + out.row(10) = bBuf.template select<4, 1, 4, 4>(8, 2); // a + out.row(11) = bBuf.template select<4, 1, 4, 4>(12, 2); // b + out.row(12) = bBuf.template select<4, 1, 4, 4>(0, 3); // c + out.row(13) = bBuf.template select<4, 1, 4, 4>(4, 3); // d + out.row(14) = bBuf.template select<4, 1, 4, 4>(8, 3); // e + out.row(15) = bBuf.template select<4, 1, 4, 4>(12, 3); // f +} + +template +CM_INLINE void Transpose_8x8(matrix_ref in, matrix_ref out) { + matrix temp; + temp.row(0) = in.template select<2, 1, 4, 2>(0, 0); + temp.row(1) = in.template select<2, 1, 4, 2>(2, 0); + temp.row(2) = in.template select<2, 1, 4, 2>(4, 0); + temp.row(3) = in.template select<2, 1, 4, 2>(6, 0); + temp.row(4) = in.template select<2, 1, 4, 2>(0, 1); + temp.row(5) = in.template select<2, 1, 4, 2>(2, 1); + temp.row(6) = in.template select<2, 1, 4, 2>(4, 1); + temp.row(7) = in.template select<2, 1, 4, 2>(6, 1); + + out.row(0) = temp.template select<4, 1, 2, 4>(0, 0); + out.row(2) = temp.template select<4, 1, 2, 4>(0, 1); + out.row(4) = temp.template select<4, 1, 2, 4>(0, 2); + out.row(6) = temp.template select<4, 1, 2, 4>(0, 3); + out.row(1) = temp.template select<4, 1, 2, 4>(4, 0); + out.row(3) = temp.template select<4, 1, 2, 4>(4, 1); + out.row(5) = temp.template select<4, 1, 2, 4>(4, 2); + out.row(7) = temp.template select<4, 1, 2, 4>(4, 3); +} + +// function templates cannot be partially specialized; use overloading to achieve the same effect +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_8x8(in, out); +} +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_16x16(in, out); +} +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); + Transpose_8x8(in.select<8, 1, 8, 1>(8,0), out.select<8, 1, 8, 1>(0,8)); +} +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); + Transpose_8x8(in.select<8, 1, 8, 1>(0,8), out.select<8, 1, 8, 1>(8,0)); +} + +template +CM_INLINE void slm_read_2d(matrix_ref out, uint slm, int offset) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_slm_block_read(slm, GENX_DWALIGNED, offset + i*n_stride*sizeof(T), out.row(i)); + } +} + +template +CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_svm_block_read(base + i * pitch, out[i]); + } +} + +template +CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + out.row(i).format() = cm_load(base, offset + i * pitch); + } +} + +template +CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + out.row(i).format() = cm_load(base, offset + i * pitch); + } +} + +template +CM_INLINE void cm_store_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_store(base, offset + i * pitch, out.row(i).format()); + } +} + +template +CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, vector_ref offsets) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_svm_block_read(base + offsets[i], out[i]); + } +} + +template +CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++, base += pitch, n_rows--) { + if (n_rows > 0) cm_svm_block_read(base, out[i]); + } +} + +template +CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++, base += pitch) { + cm_svm_block_write(base, out[i]); + } +} + +template +CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++, base += pitch) { + if (i < n_rows) cm_svm_block_write(base, out[i]); + } +} + +CM_INLINE uint64_t get_clock() { + auto clk = cm_clock(); + return ((uint64_t)clk[1]) << 32 | clk[0]; +} + + +template +inline matrix ugemm_KQ(uint slm_K, matrix_ref Qt, uint slm_offset = 0) { + matrix St; + constexpr int num_K = _kv_step/REG_M; + auto St2 = St.format(); + + matrix Kmat; + cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); + // if (cm_local_id(2) == 3 && cm_group_id(2) == 0) { + // show(Kmat.format()); + // } + #pragma unroll + for(int k = 0; k < num_K; k++) + St2.row(k) = cm_dpas(0, Qt[0].format(), Kmat[k].format()); + + #pragma unroll + for(int ri = 1; ri < num_Qt; ri++) { + cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); + #pragma unroll + for(int k = 0; k < num_K; k++) { + St2.row(k) = cm_dpas(St2.row(k), Qt[ri].format(), Kmat[k].format()); + } + } + return St; +} + +template +inline void ugemm_PV0(uint slm_V, matrix_ref P, matrix_ref rO, uint slm_offset = 0) { + constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; + + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri = 0; k < _head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); + // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { + // show(Vmat.format()); + // } + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + 0, + Vmat.format(), + P2.row(p).format()); + //show(rO[ri + p].format()); + } + } +} + +template +inline void ugemm_PV1(uint slm_V, matrix_ref P, vector_ref max_comp, + matrix_ref rO, uint slm_offset = 0) { + constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri=0; k < _head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + + cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); + // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { + // show(Vmat.format()); + // } + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < REG_M; r++) + cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); + } + + //show(rO[ri].format()); + + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + rO[ri + p].format(), + Vmat.format(), + P2.row(p).format()); + //if (kv_pos == args_verbose) show(rO[ri + p].format()); + } + // if (kv_pos == args_verbose) show(cur_O.format()); + } +} + +template +vector online_softmax_update(matrix_ref St, vector_ref cur_max, vector_ref cur_sum) { + vector new_max_t; + new_max_t = cm_max(St[0], St[1]); + for(int r = 2; r < St.n_rows(); r++) new_max_t = cm_max(new_max_t, St[r]); + new_max_t = cm_max(new_max_t, cur_max); + + // Pt = torch.exp(St - new_max) + constexpr float log2e = 1.4426950408889634f; + for(int r = 0; r < St.n_rows(); r++) St[r] = cm_exp((St[r] - new_max_t)*log2e); + + vector row_sum_t; + row_sum_t = cm_add(St[0], St[1]); + for(int r = 2; r < St.n_rows(); r++) row_sum_t = cm_add(row_sum_t, St[r]); + + vector max_comp; + max_comp = cm_exp((cur_max - new_max_t)*log2e); + cur_sum = cm_mul(cur_sum, max_comp); + cur_sum = cm_add(cur_sum, row_sum_t); + cur_max = new_max_t; + return max_comp; +} + +#ifdef CM_HAS_LSC_UNTYPED_2D + #define cm_load_normal cm_load + #define cm_load_transpose cm_load + #define cm_load_vnni cm_load + #define cm_store_normal cm_store +#else + // simulation of LSC API using SVM API + template + inline void cm_load_normal(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { + static_assert(NBlocks == 1); + auto pitch = Desc.get_pitch() + 1; + auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); + #pragma unroll + for(int i = 0; i < BlockH; i++) { + cm_svm_block_read(base + i * pitch, Res.select(i*BlockW)); + } + } + + template + inline void cm_load_transpose(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { + static_assert(NBlocks == 1); + auto pitch = Desc.get_pitch() + 1; + auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); + matrix temp; + #pragma unroll + for(int i = 0; i < BlockH; i++) { + cm_svm_block_read(base + i * pitch, temp[i]); + } + Transpose2DMatrix(temp, Res.format()); + } + + // in VNNI case, NBlocks is increasing along X dimension (increase cache-line usage) + template + inline void cm_load_vnni(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { + static_assert(NBlocks == 1 || NBlocks == 2); + // each block must be a full XMX B matrix + static_assert(BlockH == REG_K); + static_assert(BlockW == REG_N); + auto pitch = Desc.get_pitch() + 1; + auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); + matrix temp; + #pragma unroll + for(int i = 0; i < BlockH; i++) { + cm_svm_block_read(base + i * pitch, temp[i]); + } + + auto out_vnni = Res.format(); + #pragma unroll + for(int i = 0; i < NBlocks; i ++) { + out_vnni.select(i*(BlockH/2), 0) = temp.select(0, i*BlockW); + out_vnni.select(i*(BlockH/2), 1) = temp.select(1, i*BlockW); + } + } + + template + inline void cm_store_normal(const lsc::block_2d_desc &Desc, vector_ref Res) { + static_assert(NBlocks == 1); + auto pitch = Desc.get_pitch() + 1; + auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); + #pragma unroll + for(int i = 0; i < BlockH; i++) { + cm_svm_block_write(base + i * pitch, Res.select(i*BlockW)); + } + } +#endif + +//=============================================================================================== +template +constexpr void apply_causal_mask(matrix_ref St) { + if constexpr (i < N) { + St.row(i).select(0) = -3.4e38f; + apply_causal_mask(St); + } +} + +//prepack [K, N] to [K/2, N, 2] layout. +template +inline void prepackAsVNNIWidth2(matrix_ref input, matrix_ref out) { + #pragma unroll + for (int r = 0; r < K/2; r++) { + out.row(r).select(0) = input.row(r*2); + out.row(r).select(1) = input.row(r*2+1); + } +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp new file mode 100644 index 00000000000000..0535013926ab37 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp @@ -0,0 +1,475 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +#include "cm_attention_common.hpp" + +#if CMPA_KVCACHE_U8 +template +void pa_lsc_u8( + uint slm_K, + uint slm_V, + int wg_local_id, + int local_size, + int q_start, + int kv_stop, + int q_len, + int kv_len, + svmptr_t q_base [[type("svmptr_t")]], + svmptr_t k_cache_base [[type("svmptr_t")]], + svmptr_t v_cache_base [[type("svmptr_t")]], +#if SPARSE_BLOCK_SIZE > 1 + svmptr_t sparse_mask_base [[type("svmptr_t")]], + svmptr_t wg_sparse_mask_base [[type("svmptr_t")]], +#endif + svmptr_t o_base [[type("svmptr_t")]], + int32_t past_lens, + int32_t* block_indices [[type("svmptr_t")]]) { + + constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); + constexpr uint q_pitch = is_q_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; + //[block_num, kv_heads, block_size, head_size] + constexpr uint kv_pitch = head_size * sizeof(uint8_t); + + vector cur_max; + vector cur_sum; + + cur_max = -3e38f; + cur_sum = 0; + constexpr int num_P_tiles = REG_N / REG_M; + matrix rQ; + matrix rO; + + auto q_tokens_left = q_len; + static_assert(q_step == REG_N); + static_assert(kv_step == REG_K); + + if (q_tokens_left < 0) q_tokens_left = 0; + if (q_tokens_left > q_step) q_tokens_left = q_step; + + if (q_tokens_left > 0) { + lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0); + #pragma unroll + for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { + cm_load(rQ[ri].format(), b2dQ.set_block_x(k)); + rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); + } + } + + lsc::block_2d_desc b2dK(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + lsc::block_2d_desc b2dV(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + constexpr int quan_blk_stride = CMFLA_NUM_KV_HEADS * (CMFLA_HEAD_SIZE+4) * CMPA_BLOCK_SZ * sizeof(uint8_t); + int causal_left = q_start+past_lens; + + constexpr uint slm_buff_size = kv_step * head_size * sizeof(half); + int slm_buff_id_write = 0; + int slm_buff_id_read = 0; + +#if SPARSE_BLOCK_SIZE > 1 + auto skip_compute = [&](int kv_pos) { + auto kv_start_block = kv_pos / SPARSE_BLOCK_SIZE; + bool sparse_mask = *(reinterpret_cast(sparse_mask_base) + kv_start_block); + + return !sparse_mask; + }; + auto skip_load = [&](int kv_pos) { + auto kv_start_block = kv_pos / SPARSE_BLOCK_SIZE; + bool sparse_mask = *(reinterpret_cast(wg_sparse_mask_base) + kv_start_block); + return !sparse_mask; + }; +#endif + + auto load_slm_KV = [&](int kv_pos) { + if (kv_pos < kv_stop) { +#if SPARSE_BLOCK_SIZE > 1 + if (skip_load(kv_pos)) { + slm_buff_id_write++; + return; + } +#endif + auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; + uint32_t dscale_offset = cur_block_id*quan_blk_stride + \ + CMPA_BLOCK_SZ * head_size * sizeof(uint8_t) + kv_pos%CMPA_BLOCK_SZ*sizeof(half); + + uint slm_offset = (slm_buff_id_write & 3) * slm_buff_size; + vector dscale; + vector zp; + int kv_left = (kv_stop-kv_pos) > kv_step ? kv_step: (kv_stop-kv_pos); + + slm_buff_id_write ++; + if (wg_local_id < local_size/2) { + cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset), dscale); + cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset + CMPA_BLOCK_SZ*sizeof(half)), zp); + + matrix kmat; + auto quanKmat = kmat.format()[1].format(); + b2dK.set_base_ptr(reinterpret_cast(k_cache_base+cur_block_id*quan_blk_stride)); + b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ); + + for(int k = REG_K*wg_local_id; k < head_size; k += REG_K*(local_size/2)) { + cm_load(quanKmat.format(), b2dK.set_block_x(k)); + /*@bug: cm compiler in the tail process. + : loop combined with type convert. + for(int r = 0; r < kv_left; r++) { + kmat[r] = quanKmat[r]-zp[r]; + kmat[r] = cm_mul(kmat[r], dscale[r]); + } + wa: unroll all kv_step rows. set 0 to padding rows. + */ + #pragma unroll + for(int r = 0; r < kv_step; r++) { + kmat[r] = quanKmat[r]-zp[r]; + kmat[r] = cm_mul(kmat[r], dscale[r]); + } + //clear unused data to 0. + for(int r = kv_step-1; r >= kv_left; r--) + kmat[r] = 0; + cm_slm_block_write(slm_K, slm_offset + k * kv_step * sizeof(half), kmat.format()); + } + } else { + cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset), dscale); + cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset+CMPA_BLOCK_SZ*sizeof(half)), zp); + + matrix VmatVNNI; + matrix Vmat; + auto quanVmat = Vmat.format().row(1).format(); + b2dV.set_base_ptr(reinterpret_cast(v_cache_base+cur_block_id*quan_blk_stride)); + b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + + #pragma unroll + for(int k = REG_N*(wg_local_id-(local_size/2)); k < head_size; k += REG_N*(local_size/2)) { + cm_load(quanVmat.format(), b2dV.set_block_x(k)); + /*@bug: cm compiler in the tail process. + : loop combined with type convert. + for(int r = 0; r < kv_left; r++) { + Vmat[r] = quanVmat[r]-zp[r]; + Vmat[r] = cm_mul(Vmat[r], dscale[r]); + } + */ + #pragma unroll + for(int r = 0; r < kv_step;r++) { + Vmat[r] = quanVmat[r]-zp[r]; + Vmat[r] = cm_mul(Vmat[r], dscale[r]); + } + + for(int r = kv_step-1; r>=kv_left;r--) { + Vmat[r] = 0; + } + prepackAsVNNIWidth2(Vmat, VmatVNNI); + cm_slm_block_write(slm_V, slm_offset + k * REG_K * sizeof(half), VmatVNNI.format()); + } + } + } + }; + + load_slm_KV(0); + load_slm_KV(kv_step); + cm_slm_fence(CM_LOCAL_BARRIER); + cm_sbarrier(1); + + for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step,slm_buff_id_read++) { + + // load0, load1, signal1, + // [wait1, signal2, load2, read0, compute0] + // [wait2, signal3, load3, read1, compute1] + // [wait3, signal4, load4, read2, compute2] + // [wait4, signal5, load5, read3, compute3] + // + // after wait3, all workers have reached signal3, so: + // - all workers have finished load2 & read0. + // - we can start to load 4 into SLM slot 0 (i & 3) safely + // - we can start to read 2 ((i-2) & 3) safely + + + cm_fence(CM_LOCAL_BARRIER); + cm_sbarrier(0); + //if (kv_pos > 1024000) + if (kv_pos + kv_step < kv_stop) + cm_sbarrier(1); + load_slm_KV(kv_pos + kv_step*2); + + +#if SPARSE_BLOCK_SIZE > 1 + if (skip_compute(kv_pos)) { + if constexpr (use_causal_mask) + causal_left -= kv_step; + continue; + } +#endif + { + + uint slm_offset = (slm_buff_id_read & 3) * slm_buff_size; + + //# St = k @ Qt + matrix St = ugemm_KQ(slm_K, rQ, slm_offset); + if constexpr (use_causal_mask) { + // since kv_step == q_step == 16, causal_left is n*kv_step + if (causal_left == 0) { + apply_causal_mask<1>(St); + } else if (causal_left < 0) { + St = -3.4e38f; + } + causal_left -= kv_step; + } else { + int kv_tokens = kv_stop - kv_pos; + // LSC ensures no overflow-access, but mask off k-tails attn-score is still required + for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; + } + auto max_comp = online_softmax_update(St, cur_max, cur_sum); + + matrix P; + Transpose2DMatrix(St, P); + + if (kv_pos == 0) + ugemm_PV0(slm_V, P, rO, slm_offset); + else + ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); + } + } + // cm_sbarrier(0); + if (q_tokens_left == 0) return; + + //# save cur_O/cur_sum.transpose(0, 1) + matrix cur_O_f16; + cur_sum = cm_inv(cur_sum); + + lsc::block_2d_desc b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0); + + #pragma unroll + for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < cO.n_rows(); r++) { + cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); + } + } + b2dO.set_block_x(k); + cm_store(b2dO.set_block_y(0), cur_O_f16.format().row(0)); + cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); + } +} + +#else + +template +void pa_kernel_lsc_prefetch_f16( + int wg_local_id, + int q_start, + int kv_stop, // + int q_len, //q_step + int kv_len, //not used for now + svmptr_t q_base [[type("svmptr_t")]], + svmptr_t k_cache_base [[type("svmptr_t")]], + svmptr_t v_cache_base [[type("svmptr_t")]], +#if SPARSE_BLOCK_SIZE > 1 + svmptr_t sparse_mask_base [[type("svmptr_t")]], + svmptr_t wg_sparse_mask_base [[type("svmptr_t")]], +#endif + svmptr_t o_base [[type("svmptr_t")]], + int32_t past_lens, + int32_t* block_indices [[type("svmptr_t")]]) { + constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); + constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; + // constexpr uint k_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); + // constexpr uint v_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); + //[block_num, kv_heads, block_size, head_size] + constexpr uint k_pitch = head_size * sizeof(half); + constexpr uint v_pitch = k_pitch; + + vector cur_max; + vector cur_sum; + + cur_max = -3e38f; + cur_sum = 0; + constexpr int num_P_tiles = REG_N / REG_M; + matrix rQ; + matrix rO; + + auto q_tokens_left = q_len;// - q_start; + static_assert(q_step == REG_N); + static_assert(kv_step == REG_K); + + if (q_tokens_left < 0) q_tokens_left = 0; + if (q_tokens_left > q_step) q_tokens_left = q_step; + + if (q_tokens_left > 0) { + lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0); + #pragma unroll + for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { + cm_load(rQ[ri].format(), b2dQ.set_block_x(k)); + rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); + } + } + + lsc::block_2d_desc b2dK(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); + lsc::block_2d_desc b2dV(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); + + static_assert(wg_local_size == 16); + lsc::block_2d_desc prefetch_K(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); + lsc::block_2d_desc prefetch_V(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); + constexpr int blk_stride = CMFLA_NUM_KV_HEADS*CMFLA_HEAD_SIZE*CMPA_BLOCK_SZ; + int causal_left = q_start+past_lens; + + for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step) { + auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; + //For the last step, duplicate prefetch here. + uint32_t prefetch_kv_pos = (kv_pos+kv_step) >= kv_stop ? kv_pos : (kv_pos+kv_step); + auto prefetch_block_id = block_indices[prefetch_kv_pos / CMPA_BLOCK_SZ]; + //# St = k @ Qt + matrix St; + { + constexpr int num_K = kv_step/REG_M; + auto St2 = St.format(); + + matrix Kmat; + + prefetch_K.set_base_ptr((reinterpret_cast(k_cache_base)+prefetch_block_id*blk_stride)); + prefetch_K.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); + cm_prefetch(prefetch_K.set_block_x(0)); + +#if SPARSE_BLOCK_SIZE > 1 + { + auto kv_start_block = kv_pos/ SPARSE_BLOCK_SIZE; + bool sparse_mask = *(reinterpret_cast(sparse_mask_base) + kv_start_block); + if (!sparse_mask) { + if constexpr (use_causal_mask) { + causal_left -= kv_step; + } + continue; + } + } +#endif + b2dK.set_base_ptr((reinterpret_cast(k_cache_base)+cur_block_id*blk_stride)); + b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ); + cm_load(Kmat.format(), b2dK.set_block_x(0)); + #pragma unroll + for(int k = 0; k < num_K; k++) + St2.row(k) = cm_dpas( + 0, + rQ[0].format(), + Kmat[k].format()); + + #pragma unroll + for(int ri = 1; ri < head_size/REG_K; ri++) { + cm_prefetch(prefetch_K.set_block_x(ri*REG_K)); + cm_load(Kmat.format(), b2dK.set_block_x(ri*REG_K)); + #pragma unroll + for(int k = 0; k < num_K; k++) { + St2.row(k) = cm_dpas( + St2.row(k), + rQ[ri].format(), + Kmat[k].format()); + } + } + } + if constexpr (use_causal_mask) { + // since kv_step == q_step == 16, causal_left is n*kv_step + if (causal_left == 0) { + apply_causal_mask<1>(St); + } else if (causal_left < 0) { + St = -3.4e38f; + } + causal_left -= kv_step; + } else { + int kv_tokens = kv_stop - kv_pos; + // LSC ensures no overflow-access, but mask off k-tails attn-score is still required + for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; + } + + //show(St); + auto max_comp = online_softmax_update(St, cur_max, cur_sum); + + matrix P; + Transpose2DMatrix(St, P); + + prefetch_V.set_base_ptr((reinterpret_cast(v_cache_base)+prefetch_block_id*blk_stride)); + prefetch_V.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); + + b2dV.set_base_ptr((reinterpret_cast(v_cache_base)+cur_block_id*blk_stride)); + b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + if (kv_pos == 0) { + // ugemm_PV0(slm_V, P, rO, slm_offset); + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + cm_prefetch(prefetch_V.set_block_x(k)); + cm_load(Vmat.format(), b2dV.set_block_x(k)); + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + 0, + Vmat.format(), + P2.row(p).format()); + } + } + } + else { + //ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + + cm_prefetch(prefetch_V.set_block_x(k)); + cm_load(Vmat.format(), b2dV.set_block_x(k)); + + //# compensate cur_O + // matrix rO; + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < REG_M; r++) + cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); + } + + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + rO[ri + p].format(), + Vmat.format(), + P2.row(p).format()); + } + } + } + } + if (q_tokens_left == 0) return; + + //# save cur_O/cur_sum.transpose(0, 1) + matrix cur_O_f16; + cur_sum = cm_inv(cur_sum); + + lsc::block_2d_desc b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0); + + #pragma unroll + for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < cO.n_rows(); r++) { + cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); + + } + } + b2dO.set_block_x(k); + cm_store(b2dO.set_block_y(0), cur_O_f16.format().row(0)); + cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); + } +} + +#endif diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp index c8a3ed83921bf8..402cacb2e77674 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp @@ -13,407 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ -#include -#include - -//# CM-compiler is C++17 -static_assert(__cplusplus >= 201703L); -//# static_assert(__cplusplus >= 202002L); -//# static_assert(__cplusplus >= 202302L); - -#define SystolicDepth 8 -#define RepeatCount 8 -#define VNNI_WIDTH 2 -#define REG_K (SystolicDepth * VNNI_WIDTH) -#define REG_M RepeatCount -//REG_N -// Xe1: 8 -// Xe2: 16 -#define REG_N (CM_GRF_WIDTH/32) - -#define kv_step REG_K -#define q_step REG_N - -constexpr float scale_factor = CMFLA_SCALE_FACTOR; - -static_assert(q_step == 16 || q_step == 8); -static_assert(kv_step == 16); -static_assert(CM_HAS_DPAS); - -#define DEBUG_SHOW 1 -#if !DEBUG_SHOW -template -void show(const matrix mat, bool isfloat=true) { -} -#else -template -void show(const matrix mat, bool isfloat=true) { - printf("Matrix [%d, %d]:\n", M, N); - for(int m = 0; m < M; m ++) { - printf("\t["); - for(int n = 0; n < N; n ++) { - if (isfloat) - printf("%8.4f,", mat[m][n]); - else - printf("%8d,", mat[m][n]); - - } - printf("],\n"); - } - printf("]\n"); -} -#endif -template -CM_INLINE void Transpose_16x16(matrix_ref in, - matrix_ref out) { - matrix bBuf; - bBuf.row(0) = in.template select<4, 1, 4, 4>(0, 0); // 0,4,8,c - bBuf.row(1) = in.template select<4, 1, 4, 4>(4, 0); // 0,4,8,c - bBuf.row(2) = in.template select<4, 1, 4, 4>(8, 0); // 0,4,8,c - bBuf.row(3) = in.template select<4, 1, 4, 4>(12, 0); // 0,4,8,c - bBuf.row(4) = in.template select<4, 1, 4, 4>(0, 1); // 1,5,9,d - bBuf.row(5) = in.template select<4, 1, 4, 4>(4, 1); // 1,5,9,d - bBuf.row(6) = in.template select<4, 1, 4, 4>(8, 1); // 1,5,9,d - bBuf.row(7) = in.template select<4, 1, 4, 4>(12, 1); // 1,5,9,d - bBuf.row(8) = in.template select<4, 1, 4, 4>(0, 2); // 2,6,a,e - bBuf.row(9) = in.template select<4, 1, 4, 4>(4, 2); // 2,6,a,e - bBuf.row(10) = in.template select<4, 1, 4, 4>(8, 2); // 2,6,a,e - bBuf.row(11) = in.template select<4, 1, 4, 4>(12, 2); // 2,6,a,e - bBuf.row(12) = in.template select<4, 1, 4, 4>(0, 3); // 3,7,b,f - bBuf.row(13) = in.template select<4, 1, 4, 4>(4, 3); // 3,7,b,f - bBuf.row(14) = in.template select<4, 1, 4, 4>(8, 3); // 3,7,b,f - bBuf.row(15) = in.template select<4, 1, 4, 4>(12, 3); // 3,7,b,f - - out.row(0) = bBuf.template select<4, 1, 4, 4>(0, 0); // 0 - out.row(1) = bBuf.template select<4, 1, 4, 4>(4, 0); // 1 - out.row(2) = bBuf.template select<4, 1, 4, 4>(8, 0); // 2 - out.row(3) = bBuf.template select<4, 1, 4, 4>(12, 0); // 3 - out.row(4) = bBuf.template select<4, 1, 4, 4>(0, 1); // 4 - out.row(5) = bBuf.template select<4, 1, 4, 4>(4, 1); // 5 - out.row(6) = bBuf.template select<4, 1, 4, 4>(8, 1); // 6 - out.row(7) = bBuf.template select<4, 1, 4, 4>(12, 1); // 7 - out.row(8) = bBuf.template select<4, 1, 4, 4>(0, 2); // 8 - out.row(9) = bBuf.template select<4, 1, 4, 4>(4, 2); // 9 - out.row(10) = bBuf.template select<4, 1, 4, 4>(8, 2); // a - out.row(11) = bBuf.template select<4, 1, 4, 4>(12, 2); // b - out.row(12) = bBuf.template select<4, 1, 4, 4>(0, 3); // c - out.row(13) = bBuf.template select<4, 1, 4, 4>(4, 3); // d - out.row(14) = bBuf.template select<4, 1, 4, 4>(8, 3); // e - out.row(15) = bBuf.template select<4, 1, 4, 4>(12, 3); // f -} - -template -CM_INLINE void Transpose_8x8(matrix_ref in, matrix_ref out) { - matrix temp; - temp.row(0) = in.template select<2, 1, 4, 2>(0, 0); - temp.row(1) = in.template select<2, 1, 4, 2>(2, 0); - temp.row(2) = in.template select<2, 1, 4, 2>(4, 0); - temp.row(3) = in.template select<2, 1, 4, 2>(6, 0); - temp.row(4) = in.template select<2, 1, 4, 2>(0, 1); - temp.row(5) = in.template select<2, 1, 4, 2>(2, 1); - temp.row(6) = in.template select<2, 1, 4, 2>(4, 1); - temp.row(7) = in.template select<2, 1, 4, 2>(6, 1); - - out.row(0) = temp.template select<4, 1, 2, 4>(0, 0); - out.row(2) = temp.template select<4, 1, 2, 4>(0, 1); - out.row(4) = temp.template select<4, 1, 2, 4>(0, 2); - out.row(6) = temp.template select<4, 1, 2, 4>(0, 3); - out.row(1) = temp.template select<4, 1, 2, 4>(4, 0); - out.row(3) = temp.template select<4, 1, 2, 4>(4, 1); - out.row(5) = temp.template select<4, 1, 2, 4>(4, 2); - out.row(7) = temp.template select<4, 1, 2, 4>(4, 3); -} - -// function templates cannot be partially specialized; use overloading to achieve the same effect -template -inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { - Transpose_8x8(in, out); -} -template -inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { - Transpose_16x16(in, out); -} -template -inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { - Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); - Transpose_8x8(in.select<8, 1, 8, 1>(8,0), out.select<8, 1, 8, 1>(0,8)); -} -template -inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { - Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); - Transpose_8x8(in.select<8, 1, 8, 1>(0,8), out.select<8, 1, 8, 1>(8,0)); -} - -template -CM_INLINE void slm_read_2d(matrix_ref out, uint slm, int offset) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - cm_slm_block_read(slm, GENX_DWALIGNED, offset + i*n_stride*sizeof(T), out.row(i)); - } -} - -template -CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - cm_svm_block_read(base + i * pitch, out[i]); - } -} - -template -CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - out.row(i).format() = cm_load(base, offset + i * pitch); - } -} - -template -CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - out.row(i).format() = cm_load(base, offset + i * pitch); - } -} - -template -CM_INLINE void cm_store_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - cm_store(base, offset + i * pitch, out.row(i).format()); - } -} - -template -CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, vector_ref offsets) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - cm_svm_block_read(base + offsets[i], out[i]); - } -} - -template -CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++, base += pitch, n_rows--) { - if (n_rows > 0) cm_svm_block_read(base, out[i]); - } -} - -template -CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++, base += pitch) { - cm_svm_block_write(base, out[i]); - } -} - -template -CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++, base += pitch) { - if (i < n_rows) cm_svm_block_write(base, out[i]); - } -} - -CM_INLINE uint64_t get_clock() { - auto clk = cm_clock(); - return ((uint64_t)clk[1]) << 32 | clk[0]; -} - - -template -inline matrix ugemm_KQ(uint slm_K, matrix_ref Qt, uint slm_offset = 0) { - matrix St; - constexpr int num_K = _kv_step/REG_M; - auto St2 = St.format(); - - matrix Kmat; - cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); - // if (cm_local_id(2) == 3 && cm_group_id(2) == 0) { - // show(Kmat.format()); - // } - #pragma unroll - for(int k = 0; k < num_K; k++) - St2.row(k) = cm_dpas(0, Qt[0].format(), Kmat[k].format()); - - #pragma unroll - for(int ri = 1; ri < num_Qt; ri++) { - cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); - #pragma unroll - for(int k = 0; k < num_K; k++) { - St2.row(k) = cm_dpas(St2.row(k), Qt[ri].format(), Kmat[k].format()); - } - } - return St; -} - -template -inline void ugemm_PV0(uint slm_V, matrix_ref P, matrix_ref rO, uint slm_offset = 0) { - constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; - - auto P2 = P.format(); - #pragma unroll - for(int k = 0, ri = 0; k < _head_size; k += REG_N, ri += num_P_tiles) { - matrix Vmat; - cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); - // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { - // show(Vmat.format()); - // } - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - rO[ri + p] = cm_dpas( - 0, - Vmat.format(), - P2.row(p).format()); - //show(rO[ri + p].format()); - } - } -} - -template -inline void ugemm_PV1(uint slm_V, matrix_ref P, vector_ref max_comp, - matrix_ref rO, uint slm_offset = 0) { - constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; - auto P2 = P.format(); - #pragma unroll - for(int k = 0, ri=0; k < _head_size; k += REG_N, ri += num_P_tiles) { - matrix Vmat; - - cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); - // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { - // show(Vmat.format()); - // } - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - auto cO = rO[ri + p].format(); - #pragma unroll - for(int r = 0; r < REG_M; r++) - cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); - } - - //show(rO[ri].format()); - - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - rO[ri + p] = cm_dpas( - rO[ri + p].format(), - Vmat.format(), - P2.row(p).format()); - //if (kv_pos == args_verbose) show(rO[ri + p].format()); - } - // if (kv_pos == args_verbose) show(cur_O.format()); - } -} - -template -vector online_softmax_update(matrix_ref St, vector_ref cur_max, vector_ref cur_sum) { - vector new_max_t; - new_max_t = cm_max(St[0], St[1]); - for(int r = 2; r < St.n_rows(); r++) new_max_t = cm_max(new_max_t, St[r]); - new_max_t = cm_max(new_max_t, cur_max); - - // Pt = torch.exp(St - new_max) - constexpr float log2e = 1.4426950408889634f; - for(int r = 0; r < St.n_rows(); r++) St[r] = cm_exp((St[r] - new_max_t)*log2e); - - vector row_sum_t; - row_sum_t = cm_add(St[0], St[1]); - for(int r = 2; r < St.n_rows(); r++) row_sum_t = cm_add(row_sum_t, St[r]); - - vector max_comp; - max_comp = cm_exp((cur_max - new_max_t)*log2e); - cur_sum = cm_mul(cur_sum, max_comp); - cur_sum = cm_add(cur_sum, row_sum_t); - cur_max = new_max_t; - return max_comp; -} +#include "cm_attention_common.hpp" #ifdef CM_HAS_LSC_UNTYPED_2D - #define cm_load_normal cm_load - #define cm_load_transpose cm_load - #define cm_load_vnni cm_load - #define cm_store_normal cm_store -#else - // simulation of LSC API using SVM API - template - inline void cm_load_normal(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { - static_assert(NBlocks == 1); - auto pitch = Desc.get_pitch() + 1; - auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); - #pragma unroll - for(int i = 0; i < BlockH; i++) { - cm_svm_block_read(base + i * pitch, Res.select(i*BlockW)); - } - } - - template - inline void cm_load_transpose(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { - static_assert(NBlocks == 1); - auto pitch = Desc.get_pitch() + 1; - auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); - matrix temp; - #pragma unroll - for(int i = 0; i < BlockH; i++) { - cm_svm_block_read(base + i * pitch, temp[i]); - } - Transpose2DMatrix(temp, Res.format()); - } - - // in VNNI case, NBlocks is increasing along X dimension (increase cache-line usage) - template - inline void cm_load_vnni(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { - static_assert(NBlocks == 1 || NBlocks == 2); - // each block must be a full XMX B matrix - static_assert(BlockH == REG_K); - static_assert(BlockW == REG_N); - auto pitch = Desc.get_pitch() + 1; - auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); - matrix temp; - #pragma unroll - for(int i = 0; i < BlockH; i++) { - cm_svm_block_read(base + i * pitch, temp[i]); - } - - auto out_vnni = Res.format(); - #pragma unroll - for(int i = 0; i < NBlocks; i ++) { - out_vnni.select(i*(BlockH/2), 0) = temp.select(0, i*BlockW); - out_vnni.select(i*(BlockH/2), 1) = temp.select(1, i*BlockW); - } - } - - template - inline void cm_store_normal(const lsc::block_2d_desc &Desc, vector_ref Res) { - static_assert(NBlocks == 1); - auto pitch = Desc.get_pitch() + 1; - auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T)); - #pragma unroll - for(int i = 0; i < BlockH; i++) { - cm_svm_block_write(base + i * pitch, Res.select(i*BlockW)); - } - } -#endif - -//=============================================================================================== -template -constexpr void apply_causal_mask(matrix_ref St) { - if constexpr (i < N) { - St.row(i).select(0) = -3.4e38f; - apply_causal_mask(St); - } -} - -//prepack [K, N] to [K/2, N, 2] layout. -template -inline void prepackAsVNNIWidth2(matrix_ref input, matrix_ref out) { - #pragma unroll - for (int r = 0; r < K/2; r++) { - out.row(r).select(0) = input.row(r*2); - out.row(r).select(1) = input.row(r*2+1); - } -} - //@prefetch_u8 would have duplicated decompress perf issue. comments out for now. // template // void sdpa_kernel_lsc_prefetch_u8( @@ -663,9 +265,8 @@ inline void prepackAsVNNIWidth2(matrix_ref input, matrix_ref -void pa_lsc_u8( +template +void sdpa_kernel_lsc( uint slm_K, uint slm_V, int wg_local_id, @@ -675,20 +276,13 @@ void pa_lsc_u8( int q_len, int kv_len, svmptr_t q_base [[type("svmptr_t")]], - svmptr_t k_cache_base [[type("svmptr_t")]], - svmptr_t v_cache_base [[type("svmptr_t")]], -#if SPARSE_BLOCK_SIZE > 1 - svmptr_t sparse_mask_base [[type("svmptr_t")]], - svmptr_t wg_sparse_mask_base [[type("svmptr_t")]], -#endif - svmptr_t o_base [[type("svmptr_t")]], - int32_t past_lens, - int32_t* block_indices [[type("svmptr_t")]]) { + svmptr_t k_base [[type("svmptr_t")]], + svmptr_t v_base [[type("svmptr_t")]], + svmptr_t o_base [[type("svmptr_t")]]) { constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); - constexpr uint q_pitch = is_q_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; - //[block_num, kv_heads, block_size, head_size] - constexpr uint kv_pitch = head_size * sizeof(uint8_t); + constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; + constexpr uint kv_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); vector cur_max; vector cur_sum; @@ -715,152 +309,71 @@ void pa_lsc_u8( } } - lsc::block_2d_desc b2dK(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); - lsc::block_2d_desc b2dV(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); - constexpr int quan_blk_stride = CMFLA_NUM_KV_HEADS * (CMFLA_HEAD_SIZE+4) * CMPA_BLOCK_SZ * sizeof(uint8_t); - int causal_left = q_start+past_lens; + lsc::block_2d_desc b2dK(k_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); + lsc::block_2d_desc b2dV(v_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); + + int causal_left = q_start; constexpr uint slm_buff_size = kv_step * head_size * sizeof(half); int slm_buff_id_write = 0; int slm_buff_id_read = 0; -#if SPARSE_BLOCK_SIZE > 1 - auto skip_compute = [&](int kv_pos) { - auto kv_start_block = kv_pos / SPARSE_BLOCK_SIZE; - bool sparse_mask = *(reinterpret_cast(sparse_mask_base) + kv_start_block); - - return !sparse_mask; - }; - auto skip_load = [&](int kv_pos) { - auto kv_start_block = kv_pos / SPARSE_BLOCK_SIZE; - bool sparse_mask = *(reinterpret_cast(wg_sparse_mask_base) + kv_start_block); - return !sparse_mask; - }; -#endif - auto load_slm_KV = [&](int kv_pos) { if (kv_pos < kv_stop) { -#if SPARSE_BLOCK_SIZE > 1 - if (skip_load(kv_pos)) { - slm_buff_id_write++; - return; - } -#endif - auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; - uint32_t dscale_offset = cur_block_id*quan_blk_stride + \ - CMPA_BLOCK_SZ * head_size * sizeof(uint8_t) + kv_pos%CMPA_BLOCK_SZ*sizeof(half); - uint slm_offset = (slm_buff_id_write & 3) * slm_buff_size; - vector dscale; - vector zp; - int kv_left = (kv_stop-kv_pos) > kv_step ? kv_step: (kv_stop-kv_pos); - slm_buff_id_write ++; if (wg_local_id < local_size/2) { - cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset), dscale); - cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset + CMPA_BLOCK_SZ*sizeof(half)), zp); - - matrix kmat; - auto quanKmat = kmat.format()[1].format(); - b2dK.set_base_ptr(reinterpret_cast(k_cache_base+cur_block_id*quan_blk_stride)); - b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ); - + vector temp0; + b2dK.set_block_y(kv_pos); for(int k = REG_K*wg_local_id; k < head_size; k += REG_K*(local_size/2)) { - cm_load(quanKmat.format(), b2dK.set_block_x(k)); - /*@bug: cm compiler in the tail process. - : loop combined with type convert. - for(int r = 0; r < kv_left; r++) { - kmat[r] = quanKmat[r]-zp[r]; - kmat[r] = cm_mul(kmat[r], dscale[r]); - } - wa: unroll all kv_step rows. set 0 to padding rows. - */ - #pragma unroll - for(int r = 0; r < kv_step; r++) { - kmat[r] = quanKmat[r]-zp[r]; - kmat[r] = cm_mul(kmat[r], dscale[r]); - } - //clear unused data to 0. - for(int r = kv_step-1; r >= kv_left; r--) - kmat[r] = 0; - cm_slm_block_write(slm_K, slm_offset + k * kv_step * sizeof(half), kmat.format()); + cm_load(temp0, b2dK.set_block_x(k)); + cm_slm_block_write(slm_K, slm_offset + k * kv_step * sizeof(half), temp0); } } else { - cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset), dscale); - cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset+CMPA_BLOCK_SZ*sizeof(half)), zp); - - matrix VmatVNNI; - matrix Vmat; - auto quanVmat = Vmat.format().row(1).format(); - b2dV.set_base_ptr(reinterpret_cast(v_cache_base+cur_block_id*quan_blk_stride)); - b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ); - + vector temp2; + b2dV.set_block_y(kv_pos); #pragma unroll for(int k = REG_N*(wg_local_id-(local_size/2)); k < head_size; k += REG_N*(local_size/2)) { - cm_load(quanVmat.format(), b2dV.set_block_x(k)); - /*@bug: cm compiler in the tail process. - : loop combined with type convert. - for(int r = 0; r < kv_left; r++) { - Vmat[r] = quanVmat[r]-zp[r]; - Vmat[r] = cm_mul(Vmat[r], dscale[r]); - } - */ - #pragma unroll - for(int r = 0; r < kv_step;r++) { - Vmat[r] = quanVmat[r]-zp[r]; - Vmat[r] = cm_mul(Vmat[r], dscale[r]); - } - - for(int r = kv_step-1; r>=kv_left;r--) { - Vmat[r] = 0; - } - prepackAsVNNIWidth2(Vmat, VmatVNNI); - cm_slm_block_write(slm_V, slm_offset + k * REG_K * sizeof(half), VmatVNNI.format()); + cm_load(temp2, b2dV.set_block_x(k)); + cm_slm_block_write(slm_V, slm_offset + k * REG_K * sizeof(half), temp2); } } } }; - load_slm_KV(0); load_slm_KV(kv_step); cm_slm_fence(CM_LOCAL_BARRIER); cm_sbarrier(1); - for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step,slm_buff_id_read++) { + for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step, + k_base += kv_step * kv_pitch, + v_base += kv_step * kv_pitch, + slm_buff_id_read ++) { - // load0, load1, signal1, - // [wait1, signal2, load2, read0, compute0] - // [wait2, signal3, load3, read1, compute1] - // [wait3, signal4, load4, read2, compute2] - // [wait4, signal5, load5, read3, compute3] + // load0, load1, signal1, + // [wait2, signal2, load2, read0] + // [wait3, signal3, load3, read1] + // [wait4, signal4, load4, read2] + // [wait5, signal5, load5, read3] // - // after wait3, all workers have reached signal3, so: - // - all workers have finished load2 & read0. - // - we can start to load 4 into SLM slot 0 (i & 3) safely + // after wait4, all workers have reached signal3, so: + // - all workers have finished load2 & read0. + // - we can start to load 4 into SLM slot 0 (i & 3) safely // - we can start to read 2 ((i-2) & 3) safely - cm_fence(CM_LOCAL_BARRIER); cm_sbarrier(0); - //if (kv_pos > 1024000) + //if (kv_pos > 1024000) // for debugging if (kv_pos + kv_step < kv_stop) cm_sbarrier(1); - load_slm_KV(kv_pos + kv_step*2); + load_slm_KV(kv_pos + kv_step*2); -#if SPARSE_BLOCK_SIZE > 1 - if (skip_compute(kv_pos)) { - if constexpr (use_causal_mask) - causal_left -= kv_step; - continue; - } -#endif { - uint slm_offset = (slm_buff_id_read & 3) * slm_buff_size; - //# St = k @ Qt matrix St = ugemm_KQ(slm_K, rQ, slm_offset); + if constexpr (use_causal_mask) { // since kv_step == q_step == 16, causal_left is n*kv_step if (causal_left == 0) { @@ -874,6 +387,8 @@ void pa_lsc_u8( // LSC ensures no overflow-access, but mask off k-tails attn-score is still required for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; } + + //show(St); auto max_comp = online_softmax_update(St, cur_max, cur_sum); matrix P; @@ -910,32 +425,21 @@ void pa_lsc_u8( } } -#else - template -void pa_kernel_lsc_prefetch_f16( +void sdpa_kernel_lsc_prefetch( int wg_local_id, int q_start, - int kv_stop, // - int q_len, //q_step - int kv_len, //not used for now + int kv_stop, + int q_len, + int kv_len, svmptr_t q_base [[type("svmptr_t")]], - svmptr_t k_cache_base [[type("svmptr_t")]], - svmptr_t v_cache_base [[type("svmptr_t")]], -#if SPARSE_BLOCK_SIZE > 1 - svmptr_t sparse_mask_base [[type("svmptr_t")]], - svmptr_t wg_sparse_mask_base [[type("svmptr_t")]], -#endif - svmptr_t o_base [[type("svmptr_t")]], - int32_t past_lens, - int32_t* block_indices [[type("svmptr_t")]]) { + svmptr_t k_base [[type("svmptr_t")]], + svmptr_t v_base [[type("svmptr_t")]], + svmptr_t o_base [[type("svmptr_t")]]) { + constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; - // constexpr uint k_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); - // constexpr uint v_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); - //[block_num, kv_heads, block_size, head_size] - constexpr uint k_pitch = head_size * sizeof(half); - constexpr uint v_pitch = k_pitch; + constexpr uint kv_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); vector cur_max; vector cur_sum; @@ -962,46 +466,30 @@ void pa_kernel_lsc_prefetch_f16( } } - lsc::block_2d_desc b2dK(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); - lsc::block_2d_desc b2dV(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); + lsc::block_2d_desc b2dK(k_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); + lsc::block_2d_desc b2dV(v_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); static_assert(wg_local_size == 16); - lsc::block_2d_desc prefetch_K(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0); - lsc::block_2d_desc prefetch_V(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0); - constexpr int blk_stride = CMFLA_NUM_KV_HEADS*CMFLA_HEAD_SIZE*CMPA_BLOCK_SZ; - int causal_left = q_start+past_lens; - - for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step) { - auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; - //For the last step, duplicate prefetch here. - uint32_t prefetch_kv_pos = (kv_pos+kv_step) >= kv_stop ? kv_pos : (kv_pos+kv_step); - auto prefetch_block_id = block_indices[prefetch_kv_pos / CMPA_BLOCK_SZ]; + lsc::block_2d_desc prefetch_K(k_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); + lsc::block_2d_desc prefetch_V(v_base, kv_stop - 1, head_size*sizeof(half) - 1, kv_pitch - 1, 0, 0); + + int causal_left = q_start; + + for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step, + k_base += kv_step * kv_pitch, + v_base += kv_step * kv_pitch) { //# St = k @ Qt - matrix St; + matrix St; // = ugemm_KQ(slm_K, rQ, slm_offset); { constexpr int num_K = kv_step/REG_M; auto St2 = St.format(); matrix Kmat; - - prefetch_K.set_base_ptr((reinterpret_cast(k_cache_base)+prefetch_block_id*blk_stride)); - prefetch_K.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); + //cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); + prefetch_K.set_block_y(wg_local_id + kv_pos + kv_step); cm_prefetch(prefetch_K.set_block_x(0)); -#if SPARSE_BLOCK_SIZE > 1 - { - auto kv_start_block = kv_pos/ SPARSE_BLOCK_SIZE; - bool sparse_mask = *(reinterpret_cast(sparse_mask_base) + kv_start_block); - if (!sparse_mask) { - if constexpr (use_causal_mask) { - causal_left -= kv_step; - } - continue; - } - } -#endif - b2dK.set_base_ptr((reinterpret_cast(k_cache_base)+cur_block_id*blk_stride)); - b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ); + b2dK.set_block_y(kv_pos); cm_load(Kmat.format(), b2dK.set_block_x(0)); #pragma unroll for(int k = 0; k < num_K; k++) @@ -1012,6 +500,7 @@ void pa_kernel_lsc_prefetch_f16( #pragma unroll for(int ri = 1; ri < head_size/REG_K; ri++) { + //cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); cm_prefetch(prefetch_K.set_block_x(ri*REG_K)); cm_load(Kmat.format(), b2dK.set_block_x(ri*REG_K)); #pragma unroll @@ -1043,11 +532,8 @@ void pa_kernel_lsc_prefetch_f16( matrix P; Transpose2DMatrix(St, P); - prefetch_V.set_base_ptr((reinterpret_cast(v_cache_base)+prefetch_block_id*blk_stride)); - prefetch_V.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); - - b2dV.set_base_ptr((reinterpret_cast(v_cache_base)+cur_block_id*blk_stride)); - b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + b2dV.set_block_y(kv_pos); + prefetch_V.set_block_y(wg_local_id +kv_pos + kv_step); if (kv_pos == 0) { // ugemm_PV0(slm_V, P, rO, slm_offset); auto P2 = P.format(); @@ -1111,7 +597,6 @@ void pa_kernel_lsc_prefetch_f16( #pragma unroll for(int r = 0; r < cO.n_rows(); r++) { cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); - } } b2dO.set_block_x(k); @@ -1120,4 +605,208 @@ void pa_kernel_lsc_prefetch_f16( } } -#endif \ No newline at end of file +#else // CM_HAS_LSC_UNTYPED_2D + +template +void sdpa_kernel( + uint slm_K, + uint slm_V, + int wg_local_id, + int local_size, + int q_start, + int kv_stop, + int q_len, + int kv_len, + SurfaceIndex query [[type("buffer_t")]], + SurfaceIndex key [[type("buffer_t")]], + SurfaceIndex value [[type("buffer_t")]], + SurfaceIndex output [[type("buffer_t")]], + uint q_off, + uint k_off, + uint v_off, + uint o_off) { + + constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); + constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; + constexpr uint kv_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); + + vector cur_max; + vector cur_sum; + + cur_max = -3e38f; + cur_sum = 0; + + matrix rQ; + auto q_tokens_left = q_len; + static_assert(q_step == REG_N); + static_assert(kv_step == REG_K); + + if (q_tokens_left < 0) q_tokens_left = 0; + if (q_tokens_left > q_step) q_tokens_left = q_step; + + if (q_tokens_left > 0) { + // load as many as possible given one address + if constexpr (head_size == 128 || head_size == 64) { + matrix QmatI32; + cm_load_2d(QmatI32, query, q_off, q_pitch); + #pragma unroll + for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { + Transpose2DMatrix(QmatI32.select(0, k), rQ[ri].format()); + rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); + } + } else { + #pragma unroll + for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { + matrix QmatI32; + cm_load_2d(QmatI32, query, q_off + k * sizeof(uint), q_pitch); + Transpose2DMatrix(QmatI32, rQ[ri].format()); + rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); + } + } + } + + constexpr int num_P_tiles = REG_N / REG_M; + matrix rO; + int causal_left = q_start; + + constexpr uint slm_buff_size = kv_step * head_size * sizeof(half); + int slm_buff_id_write = 0; + int slm_buff_id_read = 0; + + auto load_slm_KV = [&](int kv_pos) { + //if (kv_pos < 1024000) return; + int kv_tokens = kv_stop - kv_pos; + if (kv_tokens <= 0) return; + uint slm_offset = (slm_buff_id_write & 3) * slm_buff_size; + slm_buff_id_write ++; + + // non-tail branch is faster + if (wg_local_id < local_size/2) { + //if (kv_pos > 1024000) { + matrix temp; + for(int k = REG_K * wg_local_id; k < head_size; k += REG_K*(local_size/2)) { + cm_load_2d(temp, key, k_off + k*sizeof(half), kv_pitch); + cm_slm_block_write(slm_K, + slm_offset + k * 2 * REG_M * sizeof(half), + temp.format()); + } + } else { + //if (kv_pos > 1024000) { + // read 16x16 XMX-B matrix (1x REG_N in Xe2, 2x REG_N in Xe1) + constexpr int VK_STEP = 16; + static_assert((VK_STEP % REG_N) == 0); + matrix temp2; + matrix temp_vnni; + //b2dV.set_block_y(kv_pos); + + static_assert((head_size % VK_STEP) == 0); + #pragma unroll + for(int k = VK_STEP * (wg_local_id-local_size/2); k < head_size; k += VK_STEP * (local_size/2)) { + cm_load_2d(temp2, value, v_off + k*sizeof(half), kv_pitch); + + #pragma unroll + for(int p = 0; p < VK_STEP/REG_N; p++) { + temp_vnni.select(0, 0) = temp2.select(0, p*REG_N); + temp_vnni.select(0, 1) = temp2.select(1, p*REG_N); + // show(temp_vnni); + cm_slm_block_write(slm_V, slm_offset + (k + p*REG_N) * REG_K * sizeof(half), temp_vnni.format()); + } + } + } + k_off += kv_step * kv_pitch; + v_off += kv_step * kv_pitch; + // printf(" diff= %lu\n", get_clock() - clk0); + }; + + load_slm_KV(0); + load_slm_KV(kv_step); + + cm_slm_fence(CM_LOCAL_BARRIER); + cm_sbarrier(1); + + for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step, + slm_buff_id_read ++) { + // + // load0->0, signal1, + // [load1->1, wait2, signal2, read0] + // [load2->2, wait3, signal3, read1] + // [load3->3, wait4, signal4, read2] + // [load4->0, wait5, signal5, read3] + // + // after wait4, all workers have reached signal3, so: + // - all workers have finished load2 & read0. + // - we can start to load 4 into SLM slot 0 (i & 3) safely + // - we can start to read 2 ((i-2) & 3) safely + // + cm_fence(CM_LOCAL_BARRIER); + cm_sbarrier(0); + + load_slm_KV(kv_pos + 2*kv_step); + + if (kv_pos + kv_step < kv_stop) + cm_sbarrier(1); + + //if (kv_pos < 1024000) continue; + uint slm_offset = (slm_buff_id_read & 3) * slm_buff_size; + + //=========================================================== 1807 ~ 3247 + //# St = k @ Qt + matrix St = ugemm_KQ(slm_K, rQ, slm_offset); + + if constexpr (use_causal_mask) { + if (causal_left < kv_step) { + vector cmask = 0.0f; + int p = causal_left + 1; + int v = 0; + for(; p < 0; p++) { + cmask[v] = -3.4e38f; + if (v < q_step - 1) v++; + } + for(; p < kv_step; p++) { + cmask[v] = -3.4e38f; + St[p] = cm_add(St[p], cmask); + if (v < q_step - 1) v++; + } + //if (wg_local_id == 0) show(St);return; + } + causal_left -= kv_step; + } + + // mask off k-tails + int kv_tokens = kv_stop - kv_pos; + for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; + + //show(St); + auto max_comp = online_softmax_update(St, cur_max, cur_sum); + + matrix P; + Transpose2DMatrix(St, P); + + if (kv_pos == 0) + ugemm_PV0(slm_V, P, rO, slm_offset); + else + ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); + } + + if (q_tokens_left > 0) { + //# save cur_O/cur_sum.transpose(0, 1) + matrix cur_O_f16; + cur_sum = cm_inv(cur_sum); + + #pragma unroll + for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < cO.n_rows(); r++) { + cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); + } + } + // if (i == args_verbose) show(cur_O_f16); + cm_store_2d(cur_O_f16, output, o_off + k*sizeof(half), o_pitch); + } + } +} + +#endif // !CM_HAS_LSC_UNTYPED_2D \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm index 8c9993b8e8612b..cfa33451f969e6 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm @@ -16,7 +16,7 @@ *******************************************************************************/ namespace KERNEL_NAME { -#include "cm_sdpa_common.hpp" +#include "cm_pa_common.hpp" #ifdef CM_HAS_LSC_UNTYPED_2D #define USE_LSC 1 From a06adeffc06b54a90f57eb63695c7182ee94f3e8 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 22 Sep 2025 16:58:23 +0800 Subject: [PATCH 23/96] integrate xattn_post_proc kernel and FP16 kernel works. TODOto verify u8 kvcache. --- .../src/graph/impls/cm/paged_attention.cpp | 8 ++ .../graph/impls/cm/paged_attention_gen.cpp | 74 ++++++++++++++++++- .../graph/impls/cm/paged_attention_gen.hpp | 8 ++ .../src/graph/impls/cm/xattn_post_proc.cm | 61 +++++++++++++++ 4 files changed, 150 insertions(+), 1 deletion(-) create mode 100644 src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 9e07710643671c..b246637d3e7c9e 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -37,6 +37,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { Stage::Ptr pa_multi_token = make_stage(); Stage::Ptr xattn_estimate_gemmqk = make_stage(); Stage::Ptr xattn_estimate_find_block = make_stage(); + Stage::Ptr xattn_estimate_post_proc = make_stage(); PagedAttentionCmImpl(): PrimitiveImplCM(PagedAttentionImplementationManager::get_type_info_static()) { m_rt_params = std::make_unique(); @@ -53,6 +54,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { if (xattn_block_size > 1) { add_stage(xattn_estimate_gemmqk, params); add_stage(xattn_estimate_find_block, params); + add_stage(xattn_estimate_post_proc, params); } } @@ -124,6 +126,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { pa_id++; } #endif + res_event = {execute_stage(res_event, instance, xattn_estimate_post_proc)}; } res_event = {execute_stage(res_event, instance, pa_multi_token)}; } else if (rt_params->stage == PagedAttentionStage::GENERATE) { @@ -202,6 +205,11 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto count_elements_mask = static_cast(desc->heads_num * q_block_pad * k_block_pad); internal_buffers.emplace_back(count_elements_mask, ov::element::boolean); // 4: sparse_block_mask + + const uint32_t MERGED_Q_NUM = 2; // TODO + const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); + auto count_elements_mask_merged = static_cast(desc->heads_num * q_block_pad_merged * k_block_pad); + internal_buffers.emplace_back(count_elements_mask_merged, ov::element::boolean); // 5: sparse_block_mask_wg } } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 560f5809e83070..7788b30953e85e 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -429,8 +429,10 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence_begins const size_t block_size = get_xattn_block_size(params); - if (block_size > 1) + if (block_size > 1) { args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // sparse_block_mask + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5}); // sparse_block_mask_wg + } args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); @@ -944,4 +946,74 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { }}; } +//----------------------------------------------------------------------------------------------------------------- +// XAttention Estimate post_proc generator +//----------------------------------------------------------------------------------------------------------------- +JitConstants XAttentionEstimatePostProc::get_jit_constants(const kernel_impl_params& params) const { + auto jit = XAttentionEstimateGeneratorBase::get_jit_constants(params); + + jit.make("MERGED_Q_NUM", 2); // TODO + + return jit; +} + +Arguments XAttentionEstimatePostProc::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + + // inputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // block_mask + + // outputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5}); // block_mask_merged + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_stride_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // k_block_pad + + return args; +} + +DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const { + return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + + const auto desc = params.typed_desc(); + + assert(rt_params != nullptr); + + const size_t block_size = get_xattn_block_size(params); + const size_t heads_num = desc->heads_num; + + auto out_shape = params.output_layouts[0].get_shape(); + const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; + const size_t q_len = out_shape[0]; + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t N = static_cast(kv_len / STRIDE); + const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); + const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + + const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); + const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); + const uint32_t k_block_pad = k_block_in_group * N_kq_groups; + const uint32_t q_block_pad = ceil_div(q_len, block_size); + + const uint32_t MERGED_Q_NUM = 2; // TODO + const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); + + wgs.global = {q_block_pad_merged, heads_num, 1}; + wgs.local = {1, 1, 1}; + + auto& scalars = kd.params.scalars; + std::vector scaler_value = {q_stride_pad, q_block_pad, k_block_pad}; + scalars.resize(scaler_value.size()); + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::UINT32; + scalars[i].v.u32 = static_cast(scaler_value[i]); + } + }}; +} + } // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index fadd095e688180..378e595c1aab8f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -134,4 +134,12 @@ class XAttentionEstimateFindBlock : public XAttentionEstimateGeneratorBase { [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; }; +class XAttentionEstimatePostProc : public XAttentionEstimateGeneratorBase { +public: + XAttentionEstimatePostProc() : XAttentionEstimateGeneratorBase("xattn_post_proc") {} + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + } // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm new file mode 100644 index 00000000000000..93595c3fc83a02 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +namespace KERNEL_NAME { +#include "estimate.hpp" + +// NOTE: q_stride_pad / TOKEN_IN_BLOCK <= q_block_pad, case for q_stride_pad / TOKEN_IN_BLOCK < q_block_pad: +// query = 256*16+1, then +// q_stride_pad = 256 +// q_stride_pad / TOKEN_IN_BLOCK = 32 +// q_block_pad = div_up(256*16+1, 128) = 33 +// _GENX_MAIN_ void post_proc_mask( +extern "C" _GENX_MAIN_ void KERNEL_NAME(svmptr_t block_mask ATTR, svmptr_t merged_block_mask ATTR, uint q_stride_pad, uint q_block_pad, uint k_block_pad) { + // block_mask: [b, hq, q_block_pad, k_block_pad] + // merged_block_mask: [b, hq, q_block_pad/MERGED_Q_NUM, k_block_pad] + // global: [q_block_pad/MERGED_Q_NUM, hq, b] + const int TOKEN_IN_BLOCK = BLOCK_SIZE / STRIDE; + const int TOKEN_SHARE_MAX = BLOCK_SHARE_MAX / TOKEN_IN_BLOCK; + uint m_mereged = cm_group_id(0); + uint hq = cm_group_id(1); + uint b = cm_group_id(2); + block_mask += (b * HQ + hq) * q_block_pad * k_block_pad; + merged_block_mask += (b * HQ + hq) * cm_group_count(0) * k_block_pad; + merged_block_mask += m_mereged * k_block_pad; + block_mask += m_mereged * MERGED_Q_NUM * k_block_pad; + vector one = 1; + // q is not inside mask, aka q=1~15 which is less than param `stride` + //for (int i = 0; i < MERGED_Q_NUM; i++) { + // auto q_stride_cur = m_mereged * MERGED_Q_NUM + i; + // if (q_stride_cur >= q_stride_pad / TOKEN_IN_BLOCK && q_stride_cur < q_block_pad) { + // for (int j = 0; j < k_block_pad; j += 32) { + // cm_ptr_store((int*)block_mask, j + i * k_block_pad, one.format()); + // } + // } + //} + for (int j = 0; j < k_block_pad; j += 32) { + vector new_mask = cm_ptr_load((int*)block_mask, j).format(); + for (int i = 1; i < MERGED_Q_NUM; i++) { + if (m_mereged * MERGED_Q_NUM + i < q_stride_pad / TOKEN_IN_BLOCK) { + vector cur_mask = cm_ptr_load((int*)block_mask, j + i * k_block_pad).format(); + new_mask &= cur_mask; + } + } + cm_ptr_store((int*)merged_block_mask, j, new_mask.format()); + } +} + +} // NAMESPACE From 4b391be441fb39b75c1fa51276a9511ef7f5e540 Mon Sep 17 00:00:00 2001 From: "river.li" Date: Sun, 14 Sep 2025 17:22:01 +0800 Subject: [PATCH 24/96] update partition size --- .../intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 7788b30953e85e..3e3a19634425e8 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -174,7 +174,11 @@ size_t get_partition_size() { // k_partition_blok_num = 1; // const size_t k_partition_blok_num = 16; // return k_partition_blok_num * PA_KV_CACHE_BLOCK_SIZE; // 128 - return 256; + if (PA_KV_CACHE_BLOCK_SIZE < 128) { + return 128; + } else { + return PA_KV_CACHE_BLOCK_SIZE; + } } size_t get_partition_num(const size_t kv_len) { From f2f21264bc9561890620ad84c809f49b8d53bf07 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Tue, 23 Sep 2025 11:26:20 +0800 Subject: [PATCH 25/96] enable int8 kvcache for xatten, but accuracy fails. --- .../intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 3e3a19634425e8..1e7ca0a86bf2b3 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -717,8 +717,13 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp //# loop order walks HQ first and the step is WALK_HQ, 1 means not walk HQ, 2 means walks 2 heads first. Valid value: 1, 2, 4... jit.make("WALK_HQ", desc->heads_num != desc->kv_heads_num ? 2 : 1); jit.make("IS_CAUSAL", 1); - jit.make("USE_INT8", 0); - jit.make("HEAD_SIZE_KEY", desc->k_head_size); + if (get_kv_compressed(params)) { + jit.make("USE_INT8", 1); + jit.make("HEAD_SIZE_KEY", desc->k_head_size + 2 * 2); + } else { + jit.make("USE_INT8", 0); + jit.make("HEAD_SIZE_KEY", desc->k_head_size); + } jit.make("SOFTMAX_TYPE", "float"); // for (auto& it : jit) { From 89c8577ff2d1dfd8b464f9b7069fd541eca2d713 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 24 Sep 2025 17:19:22 +0800 Subject: [PATCH 26/96] fix xattn kvcache u8 accuracy issue. --- src/plugins/intel_gpu/src/graph/debug_helper.cpp | 2 +- src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/debug_helper.cpp b/src/plugins/intel_gpu/src/graph/debug_helper.cpp index 068570dec5e579..a1581b61ecfc3b 100644 --- a/src/plugins/intel_gpu/src/graph/debug_helper.cpp +++ b/src/plugins/intel_gpu/src/graph/debug_helper.cpp @@ -193,7 +193,7 @@ void log_memory_to_file(memory::ptr mem, layout data_layout, stream& stream, std dump(actual_mem, stream, file_stream, dump_raw); else if (mem_dt == cldnn::data_types::u8) dump(actual_mem, stream, file_stream, dump_raw); - else if (mem_dt == cldnn::data_types::u8) + else if (mem_dt == cldnn::data_types::boolean) dump(actual_mem, stream, file_stream, dump_raw); else if (mem_dt == cldnn::data_types::i4 || mem_dt == cldnn::data_types::u4) dump_i4u4(mem_dt, actual_mem, stream, file_stream, dump_raw); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm index 93595c3fc83a02..62410533eac194 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm @@ -51,7 +51,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(svmptr_t block_mask ATTR, svmptr_t merge for (int i = 1; i < MERGED_Q_NUM; i++) { if (m_mereged * MERGED_Q_NUM + i < q_stride_pad / TOKEN_IN_BLOCK) { vector cur_mask = cm_ptr_load((int*)block_mask, j + i * k_block_pad).format(); - new_mask &= cur_mask; + new_mask |= cur_mask; } } cm_ptr_store((int*)merged_block_mask, j, new_mask.format()); From 024b71a456d7e1b19f4096f92845c30f16e75ce3 Mon Sep 17 00:00:00 2001 From: "river.li" Date: Wed, 24 Sep 2025 23:05:33 +0800 Subject: [PATCH 27/96] Fix 2nd accuracy issue --- .../src/graph/impls/cm/pa_single_token_finalization.cm | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm index a46e072100a83f..fef907e0f3fc3d 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm @@ -22,12 +22,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const int total_partition_num = (kv_partition_num * HEADS_NUM); // load lse - #if 0 - uint lse_offset = batch * total_partition_num + head * kv_partition_num; - vector lse_vec; - cm_svm_block_read((svmptr_t)(lse + lse_offset), lse_vec.format()); - float total_lse = cm_sum(lse_vec); - #else float total_lse = 0.0; uint lse_offset = batch * total_partition_num + head * kv_partition_num; constexpr float log2e = 1.4426950408889634f; @@ -36,12 +30,12 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( for(int k = 1; k < kv_partition_num; k ++) { lse_max = cm_max(lse_vec[k], lse_max); } + float lse_value = 0.0; #pragma unroll for(int k = 0; k < kv_partition_num; k ++) { - float lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); + lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); total_lse += lse_value; } - #endif // load input, total_partition_num = head_nums * kv_partition_num; matrix out_mat_f32 = 0; From 033304f959f130e0d7d24c92be1cd5dc59336953 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 29 Sep 2025 14:58:01 +0800 Subject: [PATCH 28/96] Fix 2nd accuracy issue --- .../src/graph/impls/cm/pa_single_token_finalization.cm | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm index fef907e0f3fc3d..66d8acf6df0ef2 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm @@ -30,10 +30,8 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( for(int k = 1; k < kv_partition_num; k ++) { lse_max = cm_max(lse_vec[k], lse_max); } - float lse_value = 0.0; - #pragma unroll for(int k = 0; k < kv_partition_num; k ++) { - lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); + float lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); total_lse += lse_value; } @@ -42,7 +40,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( matrix out_mat = 0; matrix data_mat; uint input_offset = batch * total_partition_num * HEAD_SIZE + head * kv_partition_num * HEAD_SIZE + offset; - #pragma unroll for(int k = 0; k < kv_partition_num; k ++) { cm_svm_block_read((svmptr_t)(input + input_offset), data_mat.format()); input_offset += HEAD_SIZE; From a6e72d029ba83db91cabc37ebb4a2114c2732e68 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Tue, 30 Sep 2025 15:27:53 +0800 Subject: [PATCH 29/96] fix xattn tailing issue: Q_blocks < K_blocks, as K_blocks is aligned to WGS --- .../src/graph/impls/cm/pa_multi_token.cm | 23 +++++++++--------- .../src/graph/impls/cm/paged_attention.cpp | 24 +++++++++++++++++++ .../graph/impls/cm/paged_attention_gen.cpp | 17 +++++++++---- .../graph/impls/cm/paged_attention_gen.hpp | 2 ++ 4 files changed, 51 insertions(+), 15 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm index cfa33451f969e6..d85d4b692a36ea 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm @@ -39,12 +39,16 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( int32_t* block_indices [[type("svmptr_t")]], int32_t* block_indices_begins [[type("svmptr_t")]], int32_t* subsequence_begins [[type("svmptr_t")]], + half* output [[type("svmptr_t")]], #if SPARSE_BLOCK_SIZE > 1 bool* sparse_block_mask [[type("svmptr_t")]], bool* sparse_block_mask_wg [[type("svmptr_t")]], -#endif - half* output [[type("svmptr_t")]], + int q_len, + int num_q_blocks, + int num_k_blocks) { +#else int q_len) { +#endif constexpr int is_causal = CMFLA_IS_CAUSAL; constexpr int num_heads = CMFLA_NUM_HEADS; constexpr int head_size = CMFLA_HEAD_SIZE; @@ -121,16 +125,13 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint q_offset = (q_start_sg*num_heads + h)*head_size; #if SPARSE_BLOCK_SIZE > 1 - //# sparse_block_mask [num_heads, q_blocks, kv_blocks] + //# sparse_block_mask [num_heads, num_q_blocks, num_k_blocks] + //# sparse_block_mask_wg [num_heads, wg_count_along_query, num_k_blocks] auto q_start_block = q_start_sg/ SPARSE_BLOCK_SIZE; - int q_blocks = (q_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE; - int kv_blocks = (kv_seq_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE; - //[self.num_heads, q_block_num, kv_block_num] - bool* block_mask_base = sparse_block_mask + (h * q_blocks + q_start_block)*kv_blocks; - //[self.num_heads, wg_count_along_query, kv_block_num)] - bool* wg_block_mask_base = sparse_block_mask_wg + (h * cm_group_count(2) + wg_id)*kv_blocks; - // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d, x-attn: %d, %dx%d, %p, %p\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop, q_start_block, q_blocks, kv_blocks, sparse_block_mask, block_mask_base); -#endif + bool* block_mask_base = sparse_block_mask + (h * num_q_blocks + q_start_block) * num_k_blocks; + bool* wg_block_mask_base = sparse_block_mask_wg + (h * cm_group_count(2) + wg_id) * num_k_blocks; + // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d, x-attn: %d, %dx%d, %p, %p\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop, q_start_block, num_q_blocks, num_k_blocks, sparse_block_mask, block_mask_base); + #endif #if CMPA_KVCACHE_U8 uint kv_offset = hkv*(head_size+4)*pa_block_sz; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index b246637d3e7c9e..cd733911a79169 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -58,6 +58,26 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { } } + void update_xattn_rt_params(const primitive_inst& instance) { + const auto& params = *instance.get_impl_params(); + + auto out_shape = params.output_layouts[0].get_shape(); + const size_t block_size = get_xattn_block_size(params); + const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; + const size_t q_len = out_shape[0]; + const uint32_t N = kv_len / STRIDE; + const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + + const auto q_block_pad = ceil_div(q_len, block_size); + const auto sum_per_token_in_block = block_size / STRIDE; + const auto k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; + const auto k_block_pad = k_block_in_group * N_kq_groups; + + auto rt_params = static_cast(m_rt_params.get()); + rt_params->xattn_q_block_pad = q_block_pad; + rt_params->xattn_k_block_pad = k_block_pad; + } + void update_rt_params(const primitive_inst& instance) override { update_stages_flags(instance); if (m_rt_params == nullptr) { @@ -73,6 +93,10 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { rt_params->partition_size = get_partition_size(); rt_params->num_of_partitions = ceil_div(max_context_len, rt_params->partition_size); rt_params->stage = get_paged_attention_stage(params); + const size_t block_size = get_xattn_block_size(params); + if (block_size > 1) { + update_xattn_rt_params(instance); + } GPU_DEBUG_TRACE_DETAIL << " max_context_len: " << rt_params->max_context_len << " partition_size: " << rt_params->partition_size << " num_of_partitions: " << rt_params->num_of_partitions << ", stage: " << static_cast(rt_params->stage) << std::endl; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 1e7ca0a86bf2b3..20c2aee0622b32 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -432,15 +432,19 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block_indices_begins args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence_begins + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); + const size_t block_size = get_xattn_block_size(params); if (block_size > 1) { args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // sparse_block_mask args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5}); // sparse_block_mask_wg } - args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); - - args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len + if (block_size > 1) { + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // k_block_pad + } return args; } @@ -478,7 +482,7 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con auto& wgs = kd.params.workGroups; auto& scalars = kd.params.scalars; auto desc = params.typed_desc(); - // auto rtp = static_cast(rt_params); + auto rtp = static_cast(rt_params); // assert(rt_params != nullptr); const size_t heads_num = desc->heads_num; @@ -519,6 +523,11 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con } std::vector scaler_value = {q_len}; + const size_t block_size = get_xattn_block_size(params); + if (block_size > 1) { + scaler_value.push_back(rtp->xattn_q_block_pad); + scaler_value.push_back(rtp->xattn_k_block_pad); + } scalars.resize(scaler_value.size()); for (size_t i = 0; i < scaler_value.size(); ++i) { scalars[i].t = ScalarDescriptor::Types::INT32; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 378e595c1aab8f..95655ac6c135c0 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -48,6 +48,8 @@ struct PagedAttentionRuntimeParams : public ImplRuntimeParams { size_t partition_size; size_t max_context_len; size_t paged_attention_aligned_seq_len; + size_t xattn_q_block_pad; + size_t xattn_k_block_pad; }; From f7ddc68ba58e2caf85067fd5185106518e50e683 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Tue, 23 Sep 2025 20:37:41 +0800 Subject: [PATCH 30/96] decide pa block size based whether use xattntion --- .../convert_pagedattn_inputs.cpp | 13 +++++++-- src/core/src/pass/sdpa_to_paged_attention.cpp | 6 +++- .../intel_gpu/primitives/paged_attention.hpp | 3 +- .../graph/impls/cm/paged_attention_gen.cpp | 29 +++++++++++++------ .../graph/impls/cm/paged_attention_gen.hpp | 5 ++-- .../impls/ocl_v2/sdpa/paged_attention_opt.cpp | 7 +++-- .../intel_gpu/src/graph/paged_attention.cpp | 6 +++- .../src/plugin/ops/paged_attention.cpp | 6 +++- .../src/plugin/transformations_pipeline.cpp | 4 +-- 9 files changed, 58 insertions(+), 21 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp index 7935d905a3b638..ce875e473b5070 100644 --- a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp @@ -106,18 +106,27 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co key_cache->set_element_type(key_cache_precision); value_cache->set_element_type(value_cache_precision); bool status = false; + size_t keyCacheBlockSize, valueCacheBlockSize; + if (pa_op->get_rt_info().count("k_block_size") && pa_op->get_rt_info().count("v_block_size")){ + keyCacheBlockSize = pa_op->get_rt_info()["k_block_size"].as(); + valueCacheBlockSize = pa_op->get_rt_info()["v_block_size"].as(); + } else { + keyCacheBlockSize = m_config.keyCacheBlockSize; + valueCacheBlockSize = m_config.valueCacheBlockSize; + } + if (pa_op->get_rt_info().count("num_k_heads") && pa_op->get_rt_info().count("k_head_size") && pa_op->get_rt_info().count("num_v_heads") && pa_op->get_rt_info().count("v_head_size")) { const auto key_cache_shape = init_cache_shape(pa_op->get_rt_info()["num_k_heads"].as(), pa_op->get_rt_info()["k_head_size"].as(), - m_config.keyCacheBlockSize, + keyCacheBlockSize, key_cache_precision, m_config.keyCacheGroupSize, m_config.keyCacheQuantBychannel, m_config.keyCacheDimOrder); const auto value_cache_shape = init_cache_shape(pa_op->get_rt_info()["num_v_heads"].as(), pa_op->get_rt_info()["v_head_size"].as(), - m_config.valueCacheBlockSize, + valueCacheBlockSize, value_cache_precision, m_config.valueCacheGroupSize, m_config.valueCacheQuantBychannel, diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index f53625a4482334..1276348511f9be 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -29,7 +29,11 @@ ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_in m_use_score_outputs(use_score_outputs), m_allow_score_aggregation(use_score_outputs), m_allow_cache_rotation(allow_cache_rotation), - m_allow_xattention(allow_xattention) {} + m_allow_xattention(allow_xattention) { + if (!allow_xattention) { + setenv("OV_GPU_XATTN_BLOCK_SIZE", "1", 1); + } + } static std::shared_ptr setName(std::shared_ptr node, const char* name) { // Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index 962e90dcf3ffd1..2d5a97c49f7166 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -38,7 +38,8 @@ struct paged_attention : public primitive_base { XATTENTION_STRIDE = 19, }; - static constexpr size_t block_size = 256; + static constexpr size_t block_size = 16; + static constexpr size_t block_size_xattn = 256; paged_attention() : primitive_base("", {}) {} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 1e7ca0a86bf2b3..ef005f05777c4f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -72,10 +72,6 @@ inline size_t get_input_kv_len(const RuntimeParams& params) { return kv_len; } -inline size_t get_aligned_kv_len(const size_t kv_len) { - return (kv_len + PA_KV_CACHE_BLOCK_SIZE - 1) / PA_KV_CACHE_BLOCK_SIZE * PA_KV_CACHE_BLOCK_SIZE; -} - inline bool get_kv_compressed(const RuntimeParams& params) { auto key_cache_layout = params.input_layouts[PagedAttentionInputIdx::KEY_CACHE]; if (data_type_traits::is_i8_u8(key_cache_layout.data_type)) { @@ -154,7 +150,9 @@ int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAtt int64_t aligned_seq_len = 0; if (stage == PagedAttentionStage::PREFILL) { const auto desc = impl_param.typed_desc(); - if (static_cast(paged_attention::block_size) == target_seq_len_block_size) { + int64_t pa_block_size = paged_attention::block_size; + if (desc->has_xattention) pa_block_size = paged_attention::block_size_xattn; + if (static_cast(pa_block_size) == target_seq_len_block_size) { const auto& block_indices_ps = impl_param.get_input_layout(PagedAttentionInputIdx::BLOCK_INDICES).get_partial_shape(); aligned_seq_len = block_indices_ps[0].get_length() * target_seq_len_block_size; @@ -174,6 +172,7 @@ size_t get_partition_size() { // k_partition_blok_num = 1; // const size_t k_partition_blok_num = 16; // return k_partition_blok_num * PA_KV_CACHE_BLOCK_SIZE; // 128 + // TODO: how to change PA_KV_CACHE_BLOCK_SIZE here if (PA_KV_CACHE_BLOCK_SIZE < 128) { return 128; } else { @@ -287,7 +286,11 @@ JitConstants PagedAttentionGeneratorKVCacheUpdate::get_jit_constants(const kerne jit.make("KV_HEADS_NUM", desc->kv_heads_num); jit.make("K_HEAD_SIZE", desc->k_head_size); jit.make("V_HEAD_SIZE", desc->v_head_size); - jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + if (desc->has_xattention) { + jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE_XATTN); + } else { + jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + } if (get_kv_compressed(params)) { jit.make("KV_CACHE_COMPRESSION_PER_TOKEN", 1); @@ -457,7 +460,11 @@ JitConstants PagedAttentionGeneratorMultiToken::get_jit_constants(const kernel_i jit.make("CMFLA_HEAD_SIZE", desc->k_head_size); jit.add(make_jit_constant("CMFLA_SCALE_FACTOR", scale_factor)); jit.make("CMFLA_IS_CAUSAL", 1); - jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE); + if (desc->has_xattention) { + jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE_XATTN); + } else { + jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE); + } jit.make("SPARSE_BLOCK_SIZE", xattn_block_size); jit.make("Q_STEP", get_q_step(xe_arch, true)); @@ -539,7 +546,11 @@ JitConstants PagedAttentionGeneratorSingleToken::get_jit_constants(const kernel_ auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; jit.make("KV_PARTITION_SIZE", kv_partition_size); - jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + if (desc->has_xattention) { + jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE_XATTN); + } else { + jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + } jit.add(make_jit_constant("SCALE_FACTOR", scale_factor)); jit.make("HEAD_SIZE", desc->k_head_size); jit.make("HEADS_NUM", desc->heads_num); @@ -711,7 +722,7 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp jit.make("BLOCK_SG_M", BLOCK_SG_M); jit.make("BLOCK_SG_N", BLOCK_SG_N); jit.make("BLOCK_SIZE", get_xattn_block_size(params)); - jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE_XATTN); jit.add(make_jit_constant("INV_S", scale_factor_i)); jit.make("BLOCK_SHARE_MAX", BLOCK_WG_N); //# loop order walks HQ first and the step is WALK_HQ, 1 means not walk HQ, 2 means walks 2 heads first. Valid value: 1, 2, 4... diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 378e595c1aab8f..d89633e7498153 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -31,7 +31,8 @@ constexpr auto get_pa_build_options() { } // BLOCK_SIZE can be 16/32/64/128/256 -#define PA_KV_CACHE_BLOCK_SIZE 256 +#define PA_KV_CACHE_BLOCK_SIZE 16 +#define PA_KV_CACHE_BLOCK_SIZE_XATTN 256 constexpr uint32_t BLOCK_SG_M = 64; constexpr uint32_t BLOCK_SG_N = 32; @@ -64,7 +65,7 @@ size_t get_partition_num(const size_t kv_len); const float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx); inline size_t get_xattn_block_size(const kernel_impl_params& impl_param) { return impl_param.get_program().get_config().get_xattention_block_size(); - } +} class PagedAttentionGeneratorBase : public KernelGenerator { public: diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp index 200c9a31e00398..62ebcbece1d81c 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp @@ -195,7 +195,9 @@ static int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const P int64_t aligned_seq_len = 0; if (stage == PagedAttentionStage::PREFILL) { const auto desc = impl_param.typed_desc(); - if (static_cast(paged_attention::block_size) == target_seq_len_block_size) { + int64_t pa_block_size = paged_attention::block_size; + if (desc->has_xattention) pa_block_size = paged_attention::block_size_xattn; + if (static_cast(pa_block_size) == target_seq_len_block_size) { const auto& block_indices_ps = impl_param.get_input_layout(PagedAttentionInputIdx::BLOCK_INDICES).get_partial_shape(); aligned_seq_len = block_indices_ps[0].get_length() * target_seq_len_block_size; @@ -1576,7 +1578,8 @@ class PagedAttentionOptImpl : public SDPAImplBase { size_t index = 0; size_t micro_sdpa_index = 0; size_t subsequence_offsets_acc = 0; - const auto pa_block_size = static_cast(paged_attention::block_size); + int pa_block_size = paged_attention::block_size; + if (desc->has_xattention) pa_block_size = paged_attention::block_size_xattn; for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) { const auto past_len = past_lens_mem_lock[i]; const auto seq_start = subsequence_begins_mem_lock[i]; diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index 1478eed6c0d8c9..6d4e452fe60c48 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -12,6 +12,7 @@ namespace cldnn { GPU_DEFINE_PRIMITIVE_TYPE_ID(paged_attention) constexpr size_t paged_attention::block_size; +constexpr size_t paged_attention::block_size_xattn; layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*node*/, kernel_impl_params const& impl_param) { auto out_layout = impl_param.get_input_layout(0); @@ -39,7 +40,10 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no const auto& key_cache_quant_mode = impl_param.get_program().get_config().get_key_cache_quant_mode(); bool key_cache_compressed = impl_param.get_input_layout(key_cache_idx).data_type == ov::element::i8 || impl_param.get_input_layout(key_cache_idx).data_type == ov::element::u8; - auto expected_block_size = paged_attention::block_size; + size_t expected_block_size = paged_attention::block_size; + if (desc->has_xattention) { + expected_block_size = paged_attention::block_size_xattn; + } if (key_cache_compressed && key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { expected_block_size += 4; } diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 61bbb016c65dc0..586cf542822bc6 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -94,10 +94,14 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared const size_t xattention_threshold_idx = cldnn::paged_attention::PagedAttentionInputIdx::XATTENTION_THRESHOLD; auto xattention_threshold_input = ov::as_type_ptr(op->get_input_node_shared_ptr(xattention_threshold_idx)); if (xattention_threshold_input && xattention_threshold_input->get_output_partial_shape(0).is_dynamic()) { + // TODO: enable xattention_threshold_input prim.has_xattention = true; + } else if(rt_info.find("k_block_size") != rt_info.end()) { + if(rt_info.at("k_block_size").as() == 256) { + prim.has_xattention = true; + } } - prim.is_key_by_channel = p.get_config().get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL; prim.num_outputs = 1; diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index b9e1bb8d66e460..e51f9c1b180b74 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -512,11 +512,11 @@ void TransformationsPipeline::apply(std::shared_ptr func) { kv_cache_config.keyCachePrecision = config.get_kv_cache_precision(); kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); kv_cache_config.inferencePrecision = infer_precision; - kv_cache_config.keyCacheBlockSize = 256; + kv_cache_config.keyCacheBlockSize = 16; kv_cache_config.keyCacheDimOrder = {0, 1, 2, 3}; kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; - kv_cache_config.valueCacheBlockSize = 256; + kv_cache_config.valueCacheBlockSize = 16; kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; kv_cache_config.valueCacheQuantBychannel = false; kv_cache_config.valueCacheGroupSize = 0; From 29cdabb7e31852d5813e8e4844c35c6c501cbe32 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Thu, 25 Sep 2025 10:17:56 +0800 Subject: [PATCH 31/96] fix bloxk size logic --- .../convert_pagedattn_inputs.cpp | 13 ++-------- src/core/src/pass/sdpa_to_paged_attention.cpp | 6 +---- .../graph/impls/cm/paged_attention_gen.cpp | 10 +++++-- .../intel_gpu/src/graph/paged_attention.cpp | 7 ++--- .../src/plugin/ops/paged_attention.cpp | 6 ++--- .../src/plugin/transformations_pipeline.cpp | 26 ++++++++++++++++--- 6 files changed, 39 insertions(+), 29 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp index ce875e473b5070..7935d905a3b638 100644 --- a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp @@ -106,27 +106,18 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co key_cache->set_element_type(key_cache_precision); value_cache->set_element_type(value_cache_precision); bool status = false; - size_t keyCacheBlockSize, valueCacheBlockSize; - if (pa_op->get_rt_info().count("k_block_size") && pa_op->get_rt_info().count("v_block_size")){ - keyCacheBlockSize = pa_op->get_rt_info()["k_block_size"].as(); - valueCacheBlockSize = pa_op->get_rt_info()["v_block_size"].as(); - } else { - keyCacheBlockSize = m_config.keyCacheBlockSize; - valueCacheBlockSize = m_config.valueCacheBlockSize; - } - if (pa_op->get_rt_info().count("num_k_heads") && pa_op->get_rt_info().count("k_head_size") && pa_op->get_rt_info().count("num_v_heads") && pa_op->get_rt_info().count("v_head_size")) { const auto key_cache_shape = init_cache_shape(pa_op->get_rt_info()["num_k_heads"].as(), pa_op->get_rt_info()["k_head_size"].as(), - keyCacheBlockSize, + m_config.keyCacheBlockSize, key_cache_precision, m_config.keyCacheGroupSize, m_config.keyCacheQuantBychannel, m_config.keyCacheDimOrder); const auto value_cache_shape = init_cache_shape(pa_op->get_rt_info()["num_v_heads"].as(), pa_op->get_rt_info()["v_head_size"].as(), - valueCacheBlockSize, + m_config.valueCacheBlockSize, value_cache_precision, m_config.valueCacheGroupSize, m_config.valueCacheQuantBychannel, diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index 1276348511f9be..f53625a4482334 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -29,11 +29,7 @@ ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_in m_use_score_outputs(use_score_outputs), m_allow_score_aggregation(use_score_outputs), m_allow_cache_rotation(allow_cache_rotation), - m_allow_xattention(allow_xattention) { - if (!allow_xattention) { - setenv("OV_GPU_XATTN_BLOCK_SIZE", "1", 1); - } - } + m_allow_xattention(allow_xattention) {} static std::shared_ptr setName(std::shared_ptr node, const char* name) { // Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index ef005f05777c4f..c3062cd9a5c75a 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -72,6 +72,11 @@ inline size_t get_input_kv_len(const RuntimeParams& params) { return kv_len; } +inline size_t get_aligned_kv_len(const size_t kv_len) { + // TODO: how to change PA_KV_CACHE_BLOCK_SIZE here + return (kv_len + PA_KV_CACHE_BLOCK_SIZE - 1) / PA_KV_CACHE_BLOCK_SIZE * PA_KV_CACHE_BLOCK_SIZE; +} + inline bool get_kv_compressed(const RuntimeParams& params) { auto key_cache_layout = params.input_layouts[PagedAttentionInputIdx::KEY_CACHE]; if (data_type_traits::is_i8_u8(key_cache_layout.data_type)) { @@ -436,7 +441,7 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence_begins const size_t block_size = get_xattn_block_size(params); - if (block_size > 1) { + if (desc->has_xattention && block_size > 1) { args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // sparse_block_mask args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5}); // sparse_block_mask_wg } @@ -462,10 +467,11 @@ JitConstants PagedAttentionGeneratorMultiToken::get_jit_constants(const kernel_i jit.make("CMFLA_IS_CAUSAL", 1); if (desc->has_xattention) { jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE_XATTN); + jit.make("SPARSE_BLOCK_SIZE", xattn_block_size); } else { jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE); + jit.make("SPARSE_BLOCK_SIZE", 1); } - jit.make("SPARSE_BLOCK_SIZE", xattn_block_size); jit.make("Q_STEP", get_q_step(xe_arch, true)); if (get_kv_compressed(params)) { diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index 6d4e452fe60c48..1f43ec08de16a6 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -35,7 +35,7 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no data_layout.data_padding = padding(); - const auto& key_cache_idx = cldnn::paged_attention::PagedAttentionInputIdx::KEY_CACHE; + size_t key_cache_idx = cldnn::paged_attention::PagedAttentionInputIdx::KEY_CACHE; const auto& key_cache_ps = impl_param.get_input_layout(key_cache_idx).get_partial_shape(); const auto& key_cache_quant_mode = impl_param.get_program().get_config().get_key_cache_quant_mode(); bool key_cache_compressed = impl_param.get_input_layout(key_cache_idx).data_type == ov::element::i8 || @@ -43,6 +43,7 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no size_t expected_block_size = paged_attention::block_size; if (desc->has_xattention) { expected_block_size = paged_attention::block_size_xattn; + key_cache_idx -= 1; } if (key_cache_compressed && key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { expected_block_size += 4; @@ -51,9 +52,9 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no "[GPU] Paged Attention key cache quantization mode mismatch: prim.is_key_by_channel : ", desc->is_key_by_channel, " but exec_config : ", impl_param.get_program().get_config().get_key_cache_quant_mode()); bool valid_block_size = key_cache_ps.is_dynamic() || - (key_cache_ps[key_cache_idx-1].get_length() == static_cast(expected_block_size)); + (key_cache_ps[key_cache_idx].get_length() == static_cast(expected_block_size)); OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation for key cache quant mode " - , key_cache_quant_mode, ". Expected ", expected_block_size, ", but got ", key_cache_ps[key_cache_idx-1].get_length()); + , key_cache_quant_mode, ". Expected ", expected_block_size, ", but got ", key_cache_ps[key_cache_idx].get_length()); std::vector output_layouts{ data_layout }; if (desc->has_scores_output()) { diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 586cf542822bc6..c1d2caed8ddfe9 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -96,10 +96,8 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared if (xattention_threshold_input && xattention_threshold_input->get_output_partial_shape(0).is_dynamic()) { // TODO: enable xattention_threshold_input prim.has_xattention = true; - } else if(rt_info.find("k_block_size") != rt_info.end()) { - if(rt_info.at("k_block_size").as() == 256) { - prim.has_xattention = true; - } + } else if(key_cache_ps[3].get_length() == k_head_size && key_cache_ps[2].get_length() == 256) { + prim.has_xattention = true; } prim.is_key_by_channel = p.get_config().get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL; diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index e51f9c1b180b74..0e66bf84be4df5 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -508,16 +508,34 @@ void TransformationsPipeline::apply(std::shared_ptr func) { // To handle this case, "KeepConstPrecision" is executed again. manager.register_pass(supported_woq_types, !device_info.supports_immad); + bool use_xattention = false; + const auto& parameters = func->get_parameters(); + for (const auto& param : parameters) { + if (param->get_friendly_name() == "xattention_block_size") { + use_xattention = true; + } + } + ov::pass::ConvertPagedAttnInputs::KVCacheConfig kv_cache_config; kv_cache_config.keyCachePrecision = config.get_kv_cache_precision(); kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); kv_cache_config.inferencePrecision = infer_precision; - kv_cache_config.keyCacheBlockSize = 16; - kv_cache_config.keyCacheDimOrder = {0, 1, 2, 3}; + if (use_xattention) { + kv_cache_config.keyCacheBlockSize = 256; + kv_cache_config.keyCacheDimOrder = {0, 1, 2, 3}; + } else { + kv_cache_config.keyCacheBlockSize = 16; + kv_cache_config.keyCacheDimOrder = {0, 1, 3, 2}; + } kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; - kv_cache_config.valueCacheBlockSize = 16; - kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; + if (use_xattention) { + kv_cache_config.valueCacheBlockSize = 256; + kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; + } else { + kv_cache_config.valueCacheBlockSize = 16; + kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; + } kv_cache_config.valueCacheQuantBychannel = false; kv_cache_config.valueCacheGroupSize = 0; From 50480814954fa140feb9becb555ea7a86c3e2e02 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Fri, 26 Sep 2025 12:29:11 +0800 Subject: [PATCH 32/96] fix partition size --- .../src/graph/impls/cm/paged_attention.cpp | 4 ++-- .../src/graph/impls/cm/paged_attention_gen.cpp | 13 ++++++------- .../src/graph/impls/cm/paged_attention_gen.hpp | 4 ++-- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index b246637d3e7c9e..93a704bd89f801 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -70,7 +70,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { const auto max_context_len = get_max_context_len(params); rt_params->max_context_len = max_context_len; - rt_params->partition_size = get_partition_size(); + rt_params->partition_size = get_partition_size(desc->has_xattention); rt_params->num_of_partitions = ceil_div(max_context_len, rt_params->partition_size); rt_params->stage = get_paged_attention_stage(params); @@ -162,7 +162,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { } else { stage = get_paged_attention_stage(params); const auto max_context_len = get_max_context_len(params); - partition_size = get_partition_size(); + partition_size = get_partition_size(desc->has_xattention); num_of_partitions = ceil_div(max_context_len, partition_size); } GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::get_internal_buffer_descs(): stage = " << static_cast(stage) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index c3062cd9a5c75a..0f06595b9487a8 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -171,22 +171,21 @@ int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAtt return aligned_seq_len; } -size_t get_partition_size() { +size_t get_partition_size(const bool has_xattention) { // size_t k_partition_blok_num = (kv_len + 8191) / 8192; // if (k_partition_blok_num < 1) // k_partition_blok_num = 1; // const size_t k_partition_blok_num = 16; // return k_partition_blok_num * PA_KV_CACHE_BLOCK_SIZE; // 128 - // TODO: how to change PA_KV_CACHE_BLOCK_SIZE here - if (PA_KV_CACHE_BLOCK_SIZE < 128) { + if (!has_xattention && PA_KV_CACHE_BLOCK_SIZE < 128) { return 128; } else { - return PA_KV_CACHE_BLOCK_SIZE; + return PA_KV_CACHE_BLOCK_SIZE_XATTN; } } -size_t get_partition_num(const size_t kv_len) { - const size_t partition_size = get_partition_size(); +size_t get_partition_num(const size_t kv_len, const bool has_xattention) { + const size_t partition_size = get_partition_size(has_xattention); const size_t partition_num = (kv_len + partition_size - 1) / partition_size; return partition_num; @@ -548,7 +547,7 @@ JitConstants PagedAttentionGeneratorSingleToken::get_jit_constants(const kernel_ // jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); auto desc = params.typed_desc(); const float scale_factor = 1.0 / std::sqrt(static_cast(desc->k_head_size)); - const size_t kv_partition_size = get_partition_size(); + const size_t kv_partition_size = get_partition_size(desc->has_xattention); auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; jit.make("KV_PARTITION_SIZE", kv_partition_size); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index d89633e7498153..314032f2fd0afe 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -59,8 +59,8 @@ int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAtt PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param); size_t get_max_context_len(const kernel_impl_params& params); size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx); -size_t get_partition_size(); -size_t get_partition_num(const size_t kv_len); +size_t get_partition_size(const bool has_xattention); +size_t get_partition_num(const size_t kv_len, const bool has_xattention); const float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx); inline size_t get_xattn_block_size(const kernel_impl_params& impl_param) { From 0c8c029ad4774be696c33bfd836053e3295064de Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Thu, 9 Oct 2025 13:26:36 +0800 Subject: [PATCH 33/96] fix condition of xattn stages --- src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 93a704bd89f801..8add87a338c756 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -51,7 +51,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { add_stage(pa_single_token_finalization, params); add_stage(pa_multi_token, params); const size_t xattn_block_size = get_xattn_block_size(params); - if (xattn_block_size > 1) { + if (desc->has_xattention && xattn_block_size > 1) { add_stage(xattn_estimate_gemmqk, params); add_stage(xattn_estimate_find_block, params); add_stage(xattn_estimate_post_proc, params); @@ -194,7 +194,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg const size_t block_size = get_xattn_block_size(params); - if (block_size > 1) { + if (desc->has_xattention && block_size > 1) { OPENVINO_ASSERT(block_size % STRIDE == 0, "sparse block_size must be devidable by stride."); const uint32_t q_block_pad = ceil_div(q_len, block_size); const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); From 6fbf07b35f0ca7bf508696679a321176209db0eb Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Thu, 9 Oct 2025 17:28:51 +0800 Subject: [PATCH 34/96] Add xAttention reference operation and test --- .../include/openvino/reference/xattention.hpp | 433 +++++++++ src/core/tests/reference/xattention.cpp | 550 +++++++++++ .../test_cases/paged_attention_gpu_test.cpp | 640 +------------ .../unit/test_cases/xattention_gpu_test.cpp | 871 ++++++++++++++++++ .../test_utils/paged_attention_gpu_test.hpp | 658 +++++++++++++ 5 files changed, 2514 insertions(+), 638 deletions(-) create mode 100644 src/core/reference/include/openvino/reference/xattention.hpp create mode 100644 src/core/tests/reference/xattention.cpp create mode 100644 src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp create mode 100644 src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp diff --git a/src/core/reference/include/openvino/reference/xattention.hpp b/src/core/reference/include/openvino/reference/xattention.hpp new file mode 100644 index 00000000000000..49e01042417caf --- /dev/null +++ b/src/core/reference/include/openvino/reference/xattention.hpp @@ -0,0 +1,433 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "openvino/reference/divide.hpp" +#include "openvino/reference/matmul.hpp" +#include "openvino/reference/softmax.hpp" +#include "openvino/reference/transpose.hpp" +#include "openvino/runtime/tensor.hpp" + +namespace ov::reference { + +using XAttentionBlockIndex = + std::pair; // .first is the *query* dimension block index, .second is *key* +using XAttentionRetainedBlockIndices = std::set; +using XAttentionRetainedBlockIndicesForAllHeads = std::vector; + +/** @brief Reference implementation of the XAttention sparse attention prefill mechanism + * (https://arxiv.org/abs/2503.16428) */ +template +class XAttentionBlockSelector { +public: + /** @param threshold Defines a threshold for introduced block sparsity - XAttention attempts to preserve the + * smallest subset of attention score matrix blocks so that the ratio of the attention score sum to the total sum of + * attention score matrix elements is no less than `threshold`. In other words, `threshold` defines a fraction of + * the attention score mass which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0 + * corresponding to 0% of the blocks retained, and 1.0 corresponding to 100% of the blocks retained. + * @param block_size The size of blocks into which the attention score matrix [num_heads, query_token_dimension, + * key_token_dimension] will be subdivided for purposes of determining the subset of the most important blocks + * according to `threshold`. This subdivision occurs on query and key dimensions of the attention score matrix with + * the same granularity, i.e. the resulting blocks have equal size on both dimensions. Essentially `block_size` + * defines the granularity of the eventual sparse attention computations. Must be a multiple of `stride`. + * @param stride The stride at which the full attention matrix is subsampled in a block-antidiagonal fashion to + * estimate the block importance. Note that the full attention matrix is not computed, instead the original query + * and key matrices are reshaped appropriately so that only the necessary elements are computed. Ideally, the + * computational complexity of the entire block estimation operation is `stride` times lower than the full attention + * matrix computation. + * */ + XAttentionBlockSelector(double threshold, size_t block_size, size_t stride) + : m_threshold(threshold), + m_block_size(block_size), + m_stride(stride) { + OPENVINO_ASSERT(m_block_size % m_stride == 0); + } + + /** Assuming the input tensor is either a query tensor or key tensor, reshapes it in a diagonal or antidiagonal + * fashion as appropriate so that the resulting matrices could be used to compute the block-antidiagonal subset of + * the attention matrix in further operations. For the query tensor, the antidiagonal reshaping should be applied, + * and diagonal - for the key tensor. Note that for the diagonal reshaping the data layout is effectively unchanged + * and only the shape can be adjusted in the efficient implementation of the same operation in HW. + * @param input_data Pointer to the input tensor data (query or key) + * @param input_shape Shape of the input tensor data (query or key). Expected shape is [num_heads, num_tokens, + * head_size], where `num_tokens` must be a multiple of `stride`. + * @param output_data Pointer to the output tensor data (reshaped query or key storage) + * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_tokens / stride, head_size * + * stride] + * @param is_antidiagonal Whether to reshape antidiagonally (true) or diagonally (false). Use `true` for query + * tensor and `false` for key tensor. + */ + void diagonal_reshape(const T* input_data, + const Shape& input_shape, + T* output_data, + const Shape& out_shape, + bool is_antidiagonal) { + OPENVINO_ASSERT(input_shape.size() == 3); // [num_heads, num_tokens, head_size] + OPENVINO_ASSERT(out_shape.size() == 3); + OPENVINO_ASSERT(input_shape[0] == out_shape[0]); + OPENVINO_ASSERT(input_shape[1] % m_stride == 0); + OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]); + OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]); + + size_t num_stride_steps = input_shape[1] / m_stride; + for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) { + size_t head_offset = head_idx * input_shape[1] * input_shape[2]; + for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) { + for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) { + size_t input_offset = head_offset; + size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2]; + if (is_antidiagonal) { + input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2]; + } else { + input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2]; + } + std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T)); + } + } + } + } + + /** Performs a matrix multiplication on the input tensors Q and K and scales the result in a typical attention op + * fashion, i.e. Q @ K^T / (sqrt(D) * S). Additionally rescales by the stride value, as compared to the regular + * attention. + * @param reshaped_query_data Pointer to the reshaped query input. + * @param reshaped_key_data Pointer to the reshaped key input. + * @param reshaped_query_shape Shape of the reshaped query input data. Expected shape is [num_heads, + * num_query_tokens / stride, head_size * stride]. + * @param reshaped_key_shape Shape of the reshaped key input data. Expected shape is [num_heads, num_key_tokens / + * stride, head_size * stride]. + * @param out Pointer to the output tensor data (attention logit scores) + * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / stride, + * num_key_tokens / stride] + */ + void transpose_matmul_scale(const T* reshaped_query_data, + const T* reshaped_key_data, + const Shape& reshaped_query_shape, + const Shape& reshaped_key_shape, + T* out, + const Shape& out_shape) { + OPENVINO_ASSERT(reshaped_key_shape.size() == 3); + OPENVINO_ASSERT(reshaped_query_shape.size() == 3); + OPENVINO_ASSERT(reshaped_query_shape[0] == reshaped_key_shape[0]); + OPENVINO_ASSERT(reshaped_query_shape[2] == reshaped_key_shape[2]); + + OPENVINO_ASSERT(out_shape.size() == 3); + OPENVINO_ASSERT(out_shape[0] == reshaped_query_shape[0]); + OPENVINO_ASSERT(out_shape[1] == reshaped_query_shape[1]); + OPENVINO_ASSERT(out_shape[2] == reshaped_key_shape[1]); + + ov::reference::matmul(reshaped_query_data, + reshaped_key_data, + out, + reshaped_query_shape, + reshaped_key_shape, + out_shape, + /* transpose_arg0 = */ false, + /* transpose_arg1 = */ true); + + size_t out_size = out_shape[0] * out_shape[1] * out_shape[2]; + + for (size_t i = 0; i < out_size; i++) { + // The D in the formula above refers to the original head dimension, while + // reshaped_query_shape[2] had been scaled in the process of reshaping, therefore + // the formula is also adjusted: + out[i] = out[i] / std::sqrt(reshaped_query_shape[2] * m_stride); + } + } + + /** Performs a softmax operation on the last dimension of the rank-3 input tensor. + * @param reshaped_qk_product_data Pointer to the reshaped query-key product input (attention logits pre-softmax). + * @param reshaped_qk_product_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens / + * stride, num_key_tokens / stride]. + * @param out Pointer to the output tensor data (attention scores) + * @param out_shape Shape of the output tensor data. Expected shape is strictly equal to + * `reshaped_qk_product_shape`. + */ + void softmax(const T* reshaped_qk_product_data, + const Shape& reshaped_qk_product_shape, + T* out, + const Shape& out_shape) { + OPENVINO_ASSERT(reshaped_qk_product_shape.size() == 3); + OPENVINO_ASSERT(reshaped_qk_product_shape == out_shape); + ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2}); + } + + /** Divides the input rank-3 tensor into blocks along last two dimensions, performs the addition of the values + * inside each block and outputs each block sum into corresponding positions in the output tensor downsampled along + * the same dimensions. The output tensor dimensions are such that the query and key token dimensions are + * downsampled by `block_size` when compared to the *original* query and key tensors. + * @param attention_scores_data Pointer to the attention score input. + * @param attention_score_shape Shape of the attention score input tensor. Expected shape is [num_heads, + * num_query_tokens / stride, num_key_tokens / stride], where `num_query_tokens` and `num_key_tokens` must be + * multiples of `block_size`. + * @param out Pointer to the output tensor data (block sums) + * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / block_size, + * num_key_tokens / block_size]. + */ + void block_sum_attention_scores(const T* attention_scores_data, + const Shape& attention_scores_shape, + T* out, + const Shape& out_shape) { + OPENVINO_ASSERT(attention_scores_shape.size() == 3); // [num_heads, query_antidiagonals, key_antidiagonals] + size_t antidiagonals_per_xattention_block = m_block_size / m_stride; + OPENVINO_ASSERT(attention_scores_shape[1] % antidiagonals_per_xattention_block == 0); + OPENVINO_ASSERT(attention_scores_shape[2] % antidiagonals_per_xattention_block == 0); + + OPENVINO_ASSERT(out_shape[0] == attention_scores_shape[0]); + OPENVINO_ASSERT(out_shape[1] == + attention_scores_shape[1] / antidiagonals_per_xattention_block); // query length, blocked + OPENVINO_ASSERT(out_shape[2] == + attention_scores_shape[2] / antidiagonals_per_xattention_block); // key length, blocked + + std::memset(out, 0, out_shape[0] * out_shape[1] * out_shape[2] * sizeof(T)); + + for (size_t head_idx = 0; head_idx < attention_scores_shape[0]; head_idx++) { + size_t in_head_offset = head_idx * attention_scores_shape[1] * attention_scores_shape[2]; + size_t out_head_offset = head_idx * out_shape[1] * out_shape[2]; + for (size_t query_len_idx = 0; query_len_idx < attention_scores_shape[1]; query_len_idx++) { + for (size_t key_len_idx = 0; key_len_idx < attention_scores_shape[2]; key_len_idx++) { + size_t query_block_idx = query_len_idx / antidiagonals_per_xattention_block; + size_t key_block_idx = key_len_idx / antidiagonals_per_xattention_block; + auto target_block_sum_ptr = out + out_head_offset + query_block_idx * out_shape[2] + key_block_idx; + *target_block_sum_ptr += *(attention_scores_data + in_head_offset + + query_len_idx * attention_scores_shape[2] + key_len_idx); + } + } + } + } + + /** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension, + * so that the elements constitute a smallest subset constituting a sum portion no less than `threshold` of the + * total element sum. + * @param blocked_scores_data Pointer to the blocked score input. + * @param blocked_attention_scores_shape Shape of the blocked score input tensor. Expected shape is [num_heads, + * num_query_tokens / block_size, num_key_tokens / block_size] + * @return A vector of size `num_heads` of sets, each set containing pairs of block indices (.first is the block + * index along the query dimension, .second - along the key). Each set is the head-specific subset of blocks + * corresponding to the property described above. + */ +// template +// void print_blocked_attention_scores(const T* data, +// size_t num_heads, +// size_t num_q_blocks, +// size_t num_k_blocks) { +// std::cout << "blocked_attention_scores shape: [" +// << num_heads << ", " << num_q_blocks << ", " << num_k_blocks << "]\n"; + +// for (size_t h = 0; h < num_heads; ++h) { +// std::cout << "Head " << h << ":\n"; +// std::cout << std::setw(8) << ""; +// for (size_t k = 0; k < num_k_blocks; ++k) { +// std::cout << std::setw(12) << ("K" + std::to_string(k)); +// } +// std::cout << "\n"; + +// for (size_t q = 0; q < num_q_blocks; ++q) { +// std::cout << std::setw(6) << ("Q" + std::to_string(q)) << " "; +// double row_sum = 0.0; +// for (size_t k = 0; k < num_k_blocks; ++k) { +// size_t idx = h * (num_q_blocks * num_k_blocks) + q * num_k_blocks + k; +// double v = static_cast(static_cast(*(data + idx))); +// row_sum += v; +// std::cout << std::setw(12) << std::fixed << std::setprecision(6) << v; +// } +// std::cout << " sum=" << std::fixed << std::setprecision(6) << row_sum << "\n"; +// } +// std::cout << std::flush; +// } +// } +// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( +// const T* blocked_attention_scores_data, +// const Shape& blocked_attention_scores_shape) { +// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); +// // [num_heads, num_blocks_in_query, num_blocks_in_key] + +// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + +// struct IndexAndScore { +// XAttentionBlockIndex idx; +// T score; +// }; + +// const size_t num_heads = blocked_attention_scores_shape[0]; +// const size_t num_q_blocks = blocked_attention_scores_shape[1]; +// const size_t num_k_blocks = blocked_attention_scores_shape[2]; +// print_blocked_attention_scores(blocked_attention_scores_data, num_heads, num_q_blocks, num_k_blocks); + +// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { +// size_t head_offset = head_idx * num_q_blocks * num_k_blocks; + +// for (size_t q_block_idx = 0; q_block_idx < num_q_blocks; q_block_idx++) { +// std::vector indices_and_scores; +// indices_and_scores.reserve(num_k_blocks); + +// double total_sum = 0.0; + +// for (size_t k_block_idx = 0; k_block_idx < num_k_blocks; k_block_idx++) { +// size_t target_offset = head_offset + q_block_idx * num_k_blocks + k_block_idx; +// T current_score = *(blocked_attention_scores_data + target_offset); +// indices_and_scores.push_back({{q_block_idx, k_block_idx}, current_score}); +// total_sum += current_score; +// } + +// double required_sum = m_threshold * total_sum; + +// std::sort(indices_and_scores.begin(), indices_and_scores.end(), +// [](const IndexAndScore& a, const IndexAndScore& b) { +// return a.score > b.score; +// }); + +// std::vector shifted_cumsum(num_k_blocks, 0.0); +// for (size_t i = 1; i < num_k_blocks; i++) { +// shifted_cumsum[i] = shifted_cumsum[i - 1] + indices_and_scores[i - 1].score; +// } + +// for (size_t i = 0; i < num_k_blocks; i++) { +// if (shifted_cumsum[i] < required_sum) { +// retval[head_idx].insert(indices_and_scores[i].idx); +// } +// } +// } +// } + +// return retval; +// } + + XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, + const Shape& blocked_attention_scores_shape) { + OPENVINO_ASSERT(blocked_attention_scores_shape.size() == + 3); // [num_heads, num_blocks_in_query, num_blocks_in_key] + + auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + + struct IndexAndScore { + XAttentionBlockIndex idx; + T score; + bool operator<(const IndexAndScore& rhs) const { + return score < rhs.score; + } + }; + + for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) { + size_t head_offset = head_idx * blocked_attention_scores_shape[1] * blocked_attention_scores_shape[2]; + std::priority_queue indices_and_scores_queue; + double total_sum = 0.0; + for (size_t q_block_idx = 0; q_block_idx < blocked_attention_scores_shape[1]; q_block_idx++) { + for (size_t k_block_idx = 0; k_block_idx < blocked_attention_scores_shape[2]; k_block_idx++) { + size_t target_offset = head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx; + T current_score = *(blocked_attention_scores_data + target_offset); + indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score}); + total_sum += current_score; + } + } + double cumsum = 0.0; + double required_sum = m_threshold * total_sum; + while (cumsum < required_sum && !indices_and_scores_queue.empty()) { + auto index_and_largest_score = indices_and_scores_queue.top(); + indices_and_scores_queue.pop(); + cumsum += index_and_largest_score.score; + retval[head_idx].insert(index_and_largest_score.idx); + } + } + return retval; + } + + /** Applies XAttention to the provided query and key matrices, returning the subset of the most important blocks for + * each attention head, according to the configured block size and threshold, which are to be preserved in the + * subsequent sparse attention computation. + * @param query_data Pointer to the query input tensor data + * @param query_shape Shape of the query input tensor data. Expected shape is [num_heads, num_query_tokens, + * head_size], where `num_query_tokens` must be a multiple of both `block_size` and `stride`, padded with zeroes if + * necessary to do so in the real-world scenario. + * @param key_data Pointer to the key input tensor data + * @param key_shape Shape of the key input tensor data. Expected shape is [num_heads, num_key_tokens, head_size], + * where `num_key_tokens` must be a multiple of both `block_size` and `stride`, padded with zeroes if necessary to + * do so in the real-world scenario. + * @return A vector of size `num_heads` of sets, each set containing pairs of block indices (.first is the block + * index along the query dimension, .second - along the key). Each set is the head-specific subset of blocks that + * must be preserved in the sparse attention computation. Indices are given in units of XAttention-specific + * `block_size` (as configured), which may differ from the block size in the paged attention implementation. + */ + XAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, + const Shape& query_shape, + const T* key_data, + const Shape& key_shape) { + OPENVINO_ASSERT(query_shape.size() == 3); // [num_heads, query_token_len, head_dim] + OPENVINO_ASSERT(key_shape.size() == 3); // [num_heads, key_token_len, head_dim] + + OPENVINO_ASSERT(key_shape[0] == query_shape[0]); + OPENVINO_ASSERT(key_shape[2] == query_shape[2]); + + OPENVINO_ASSERT(query_shape[1] % m_stride == 0); + OPENVINO_ASSERT(key_shape[1] % m_stride == 0); + + OPENVINO_ASSERT(query_shape[1] % m_block_size == 0); + OPENVINO_ASSERT(key_shape[1] % m_block_size == 0); + + Shape reshaped_query_shape = {query_shape[0], query_shape[1] / m_stride, query_shape[2] * m_stride}; + auto q_buf = allocate_buf(reshaped_query_shape); + diagonal_reshape(query_data, query_shape, q_buf.get(), reshaped_query_shape, /* is_antidiagonal = */ true); + + Shape reshaped_key_shape = {key_shape[0], key_shape[1] / m_stride, key_shape[2] * m_stride}; + auto k_buf = allocate_buf(reshaped_key_shape); + diagonal_reshape(key_data, key_shape, k_buf.get(), reshaped_key_shape, /* is_antidiagonal = */ false); + + Shape transpose_matmul_scaled_shape = {key_shape[0], query_shape[1] / m_stride, key_shape[1] / m_stride}; + auto qk_buf = allocate_buf(transpose_matmul_scaled_shape); + transpose_matmul_scale(q_buf.get(), + k_buf.get(), + reshaped_query_shape, + reshaped_key_shape, + qk_buf.get(), + transpose_matmul_scaled_shape); + q_buf.reset(); + k_buf.reset(); + + Shape attention_scores_shape = transpose_matmul_scaled_shape; + auto attn_score_buf = allocate_buf(attention_scores_shape); + softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape); + qk_buf.reset(); + + size_t antidiagonals_per_xattention_block = m_block_size / m_stride; + Shape block_sum_shape = {attention_scores_shape[0], + attention_scores_shape[1] / antidiagonals_per_xattention_block, + attention_scores_shape[2] / antidiagonals_per_xattention_block}; + auto block_sum_buf = allocate_buf(block_sum_shape); + block_sum_attention_scores(attn_score_buf.get(), attention_scores_shape, block_sum_buf.get(), block_sum_shape); + attn_score_buf.reset(); + + auto selected_block_indices = get_block_indices_to_keep(block_sum_buf.get(), block_sum_shape); + block_sum_buf.reset(); + + return selected_block_indices; + } + + /** + * @param shape Shape of a tensor + * @return A shared_ptr owning a buffer that can be used to store tensor data for the given shape. + * */ + std::shared_ptr allocate_buf(const Shape& shape) { + return std::shared_ptr(new T[ov::shape_size(shape)]); + } + + /** + * @param token_length An integer value + * @return The closest multiple of `block_size` to `token_length`, rounding up. + * */ + size_t pad_to_block(size_t token_length) { + return (token_length + m_block_size - 1) / m_block_size * m_block_size; + } + + double m_threshold; + size_t m_block_size; + size_t m_stride; +}; + +} // namespace ov::reference \ No newline at end of file diff --git a/src/core/tests/reference/xattention.cpp b/src/core/tests/reference/xattention.cpp new file mode 100644 index 00000000000000..78ad6744c17053 --- /dev/null +++ b/src/core/tests/reference/xattention.cpp @@ -0,0 +1,550 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include + +double DEFAULT_THRESHOLD = 0.8; +size_t DEFAULT_BLOCK_SIZE = 32; +size_t DEFAULT_STRIDE = 8; + +struct E2EBlockSelectTestData { + ov::Shape q_shape; + std::vector q_data; + ov::Shape k_shape; + std::vector k_data; + double threshold; + size_t block_size; + size_t stride; +}; + +using XAttentionE2EBlockSelectTest = ::testing::TestWithParam; + +std::vector E2E_BLOCK_SELECT_TEST_CASES = {{ + {2, 4, 4}, + // clang-format off + { + 3.144, 8.512, 8.518, -8.386, + 7.889, -5.721, 5.507, 4.295, + -6.624, -8.463, 7.474, 9.879, + 4.534, -5.908, -9.388, 2.356, + + 7.497, 8.186, -8.658, -4.796, + -8.248, -9.797, -7.907, -4.513, + 3.469, 7.633, 7.244, -6.844, + -7.173, 4.450, 6.705, -7.035 + }, + // clang-format on + {2, 4, 4}, + // clang-format off + { + 3.144, 8.512, 8.518, -8.386, + 7.889, -5.721, 5.507, 4.295, + -6.624, -8.463, 7.474, 9.879, + 4.534, -5.908, -9.388, 2.356, + + 7.497, 8.186, -8.658, -4.796, + -8.248, -9.797, -7.907, -4.513, + 3.469, 7.633, 7.244, -6.844, + -7.173, 4.450, 6.705, -7.035 + }, + // clang-format on + + /* threshold = */ 0.8, + /* block_size = */ 2, + /* stride = */ 2, +}}; + +TEST_P(XAttentionE2EBlockSelectTest, SelectsBlocksWithoutThrowing) { + auto test_struct = GetParam(); + ov::reference::XAttentionBlockSelector selector(test_struct.threshold, + test_struct.block_size, + test_struct.stride); + + EXPECT_NO_THROW(selector.select_blocks(test_struct.q_data.data(), + test_struct.q_shape, + test_struct.k_data.data(), + test_struct.k_shape)); +}; + +INSTANTIATE_TEST_SUITE_P(VariousInputs, XAttentionE2EBlockSelectTest, ::testing::ValuesIn(E2E_BLOCK_SELECT_TEST_CASES)); + +struct DiagonalReshapeTestData { + ov::Shape in_shape; + std::vector in_data; + bool is_antidiagonal; + size_t block_size; + size_t stride; + ov::Shape out_shape; + std::vector ref_out_data; +}; + +using XAttentionDiagonalReshapeTest = ::testing::TestWithParam; + +std::vector DIAGONAL_RESHAPE_TEST_CASES = { + { + {2, 4, 4}, + // clang-format off + { + 3.144, 8.512, 8.518, -8.386, + 7.889, -5.721, 5.507, 4.295, + -6.624, -8.463, 7.474, 9.879, + 4.534, -5.908, -9.388, 2.356, + + 7.497, 8.186, -8.658, -4.796, + -8.248, -9.797, -7.907, -4.513, + 3.469, 7.633, 7.244, -6.844, + -7.173, 4.450, 6.705, -7.035 + }, + // clang-format on + + /* is_antidiagonal = */ true, + /* block_size = */ 2, + /* stride = */ 2, + {2, 2, 8}, + + // clang-format off + { + 4.534, -5.908, -9.388, 2.356, -6.624, -8.463, 7.474, 9.879, + 7.889, -5.721, 5.507, 4.295, 3.144, 8.512, 8.518, -8.386, + + -7.173, 4.450, 6.705, -7.035, 3.469, 7.633, 7.244, -6.844, + -8.248, -9.797, -7.907, -4.513, 7.497, 8.186, -8.658, -4.796, + }, + // clang-format on + }, + { + {2, 4, 4}, + // clang-format off + { + 3.144, 8.512, 8.518, -8.386, + 7.889, -5.721, 5.507, 4.295, + -6.624, -8.463, 7.474, 9.879, + 4.534, -5.908, -9.388, 2.356, + + 7.497, 8.186, -8.658, -4.796, + -8.248, -9.797, -7.907, -4.513, + 3.469, 7.633, 7.244, -6.844, + -7.173, 4.450, 6.705, -7.035 + }, + // clang-format on + + /* is_antidiagonal = */ false, + /* block_size = */ 2, + /* stride = */ 2, + {2, 2, 8}, + + // clang-format off + { + 3.144, 8.512, 8.518, -8.386, 7.889, -5.721, 5.507, 4.295, + -6.624, -8.463, 7.474, 9.879, 4.534, -5.908, -9.388, 2.356, + + 7.497, 8.186, -8.658, -4.796, -8.248, -9.797, -7.907, -4.513, + 3.469, 7.633, 7.244, -6.844, -7.173, 4.450, 6.705, -7.035 + }, + // clang-format on + }, + { + {2, 9, 2}, + // clang-format off + { + 1.110, -4.244, + 3.530, -1.083, + 3.664, -2.459, + 3.930, -2.122, + -4.142, 2.837, + -7.413, 5.855, + 1.354, -7.748, + 0.264, 7.095, + -8.410, 6.247, + + -7.832, 9.163, + -7.414, -3.682, + -5.429, 7.854, + 1.767, 5.950, + -0.841, 1.935, + 3.568, 8.530, + 9.438, -2.421, + -5.892, 7.820, + -9.869, -7.636 + }, + // clang-format on + + /* is_antidiagonal = */ true, + /* block_size = */ 9, + /* stride = */ 3, + {2, 3, 6}, + + // clang-format off + { + -8.410, 6.247, 0.264, 7.095, 1.354, -7.748, + -7.413, 5.855, -4.142, 2.837, 3.930, -2.122, + 3.664, -2.459, 3.530, -1.083, 1.110, -4.244, + + -9.869, -7.636, -5.892, 7.820, 9.438, -2.421, + 3.568, 8.530, -0.841, 1.935, 1.767, 5.950, + -5.429, 7.854, -7.414, -3.682, -7.832, 9.163, + }, + // clang-format on + }, + { + {2, 9, 2}, + // clang-format off + { + 1.110, -4.244, + 3.530, -1.083, + 3.664, -2.459, + 3.930, -2.122, + -4.142, 2.837, + -7.413, 5.855, + 1.354, -7.748, + 0.264, 7.095, + -8.410, 6.247, + + -7.832, 9.163, + -7.414, -3.682, + -5.429, 7.854, + 1.767, 5.950, + -0.841, 1.935, + 3.568, 8.530, + 9.438, -2.421, + -5.892, 7.820, + -9.869, -7.636 + }, + // clang-format on + + /* is_antidiagonal = */ false, + /* block_size = */ 9, + /* stride = */ 3, + {2, 3, 6}, + + // clang-format off + { + 1.110, -4.244, 3.530, -1.083, 3.664, -2.459, + 3.930, -2.122, -4.142, 2.837, -7.413, 5.855, + 1.354, -7.748, 0.264, 7.095, -8.410, 6.247, + + -7.832, 9.163, -7.414, -3.682, -5.429, 7.854, + 1.767, 5.950, -0.841, 1.935, 3.568, 8.530, + 9.438, -2.421, -5.892, 7.820, -9.869, -7.636 + }, + // clang-format on + }, +}; + +TEST_P(XAttentionDiagonalReshapeTest, ReshapesDiagonally) { + auto test_struct = GetParam(); + ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape)); + ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape)); + + ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, + test_struct.block_size, + test_struct.stride); + std::vector test_out_data(test_struct.ref_out_data.size()); + selector.diagonal_reshape(test_struct.in_data.data(), + test_struct.in_shape, + test_out_data.data(), + test_struct.out_shape, + test_struct.is_antidiagonal); + EXPECT_EQ(test_out_data, test_struct.ref_out_data); +} + +INSTANTIATE_TEST_SUITE_P(VariousInputs, + XAttentionDiagonalReshapeTest, + ::testing::ValuesIn(DIAGONAL_RESHAPE_TEST_CASES)); + +struct TransposeMatmulScaleTestData { + ov::Shape reshaped_query_shape; + std::vector reshaped_query_data; + ov::Shape reshaped_key_shape; + std::vector reshaped_key_data; + size_t block_size; + size_t stride; + ov::Shape out_shape; + std::vector ref_out_data; +}; + +using XAttentionTransposeMatmulScaleTest = ::testing::TestWithParam; + +std::vector TRANSPOSE_MATMUL_SCALE_TEST_CASES = { + { + {2, 2, 8}, + // clang-format off + { + 4.534, -5.908, -9.388, 2.356, -6.624, -8.463, 7.474, 9.879, + 7.889, -5.721, 5.507, 4.295, 3.144, 8.512, 8.518, -8.386, + + -7.173, 4.450, 6.705, -7.035, 3.469, 7.633, 7.244, -6.844, + -8.248, -9.797, -7.907, -4.513, 7.497, 8.186, -8.658, -4.796, + }, + // clang-format on + + {2, 3, 8}, + + // clang-format off + { + -2.731, -0.545, 6.128, -6.175, -2.198, -1.275, -8.617, -0.683, + 3.085, 7.929, -1.127, 5.369, -6.891, 9.582, -6.954, 1.189, + -0.610, -6.310, -9.216, -1.196, 9.509, -8.119, 4.652, -4.435, + + -0.026, -9.294, 7.862, 9.318, -6.012, 8.252, -3.224, -0.710, + -2.915, -7.362, -5.553, 0.097, -4.509, 6.993, 2.021, 2.870, + -3.682, 8.637, -9.922, -6.336, -2.949, 4.339, -2.807, -9.192 + }, + + /* block_size = */ 2, + /* stride = */ 2, + {2, 2, 3}, + + // clang-format off + { + -31.760349, -21.32551225, 28.723734, + -24.15923075, -3.369805999, 3.2507255, + + -7.593187497, -4.258293245, 27.08950801, + 10.21206450, 32.95415775, 33.649577 + }, + // clang-format on + }, +}; + +TEST_P(XAttentionTransposeMatmulScaleTest, TransposesMatmulsAndScales) { + auto test_struct = GetParam(); + ASSERT_EQ(test_struct.reshaped_key_data.size(), ov::shape_size(test_struct.reshaped_key_shape)); + ASSERT_EQ(test_struct.reshaped_query_data.size(), ov::shape_size(test_struct.reshaped_query_shape)); + ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape)); + + ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, + test_struct.block_size, + test_struct.stride); + std::vector test_out_data(test_struct.ref_out_data.size()); + selector.transpose_matmul_scale(test_struct.reshaped_query_data.data(), + test_struct.reshaped_key_data.data(), + test_struct.reshaped_query_shape, + test_struct.reshaped_key_shape, + test_out_data.data(), + test_struct.out_shape); + + EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-8), test_struct.ref_out_data)); +} + +INSTANTIATE_TEST_SUITE_P(VariousInputs, + XAttentionTransposeMatmulScaleTest, + ::testing::ValuesIn(TRANSPOSE_MATMUL_SCALE_TEST_CASES)); + +struct SoftmaxTestData { + ov::Shape in_shape; + std::vector in_data; + ov::Shape out_shape; + std::vector ref_out_data; +}; + +using XAttentionSoftmaxTest = ::testing::TestWithParam; + +std::vector SOFTMAX_TEST_CASES = { + { + {2, 2, 4}, + // clang-format off + { + 4.534, -5.908, -9.388, 2.356, + 7.889, -5.721, 5.507, 4.295, + + -7.173, 4.450, 6.705, -7.035, + -8.248, -9.797, -7.907, -4.513 + }, + // clang-format on + + {2, 2, 4}, + + // clang-format off + { + 0.898232, 2.62111e-05, 8.07497e-07, 0.101741, + 0.892973, 1.09671e-06, 0.08248, 0.0245462, + + 8.50252e-07, 0.0949189, 0.905079, 9.76069e-07, + 0.0224685, 0.00477366, 0.0315986, 0.941159 + }, + }, +}; + +TEST_P(XAttentionSoftmaxTest, SoftmaxIsCorrect) { + auto test_struct = GetParam(); + ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape)); + ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape)); + + ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, + DEFAULT_BLOCK_SIZE, + DEFAULT_STRIDE); + std::vector test_out_data(test_struct.ref_out_data.size()); + selector.softmax(test_struct.in_data.data(), test_struct.in_shape, test_out_data.data(), test_struct.out_shape); + + EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-5), test_struct.ref_out_data)); +} + +INSTANTIATE_TEST_SUITE_P(VariousInputs, + XAttentionSoftmaxTest, + ::testing::ValuesIn(SOFTMAX_TEST_CASES)); + +struct BlockSumTestData { + ov::Shape in_shape; + std::vector in_data; + size_t block_size; + size_t stride; + ov::Shape out_shape; + std::vector ref_out_data; +}; + +using XAttentionBlockSumTest = ::testing::TestWithParam; + +std::vector BLOCK_SUM_TEST_CASES = { + { + {2, 4, 8}, + // clang-format off + { + 0.1117, 0.0780, 0.1347, 0.0885, 0.1942, 0.0922, 0.1184, 0.1824, + 0.1488, 0.1766, 0.0852, 0.1239, 0.0930, 0.1220, 0.1367, 0.1138, + 0.1410, 0.0861, 0.0774, 0.1325, 0.1478, 0.1689, 0.0885, 0.1579, + 0.1248, 0.1038, 0.1842, 0.0935, 0.1813, 0.0890, 0.0897, 0.1336, + + 0.0905, 0.1049, 0.1263, 0.0953, 0.1018, 0.1297, 0.1659, 0.1855, + 0.1373, 0.1791, 0.1005, 0.1286, 0.1492, 0.1373, 0.0820, 0.0860, + 0.0997, 0.1285, 0.0786, 0.1366, 0.1963, 0.0904, 0.1488, 0.1211, + 0.1859, 0.1174, 0.1364, 0.0930, 0.1028, 0.1034, 0.1699, 0.0912 + }, + // clang-format on + + /* block_size = */ 8, + /* stride = */ 4, + {2, 2, 4}, + + // clang-format off + { + 0.5151, 0.4323, 0.5014, 0.5513, + 0.4557, 0.4876, 0.5870, 0.4697, + + 0.5118, 0.4507, 0.5180, 0.5194, + 0.5315, 0.4446, 0.4929, 0.5310 + }, + }, +}; +TEST_P(XAttentionBlockSumTest, BlockSumIsCorrect) { + auto test_struct = GetParam(); + ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape)); + ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape)); + + ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, + test_struct.block_size, + test_struct.stride); + std::vector test_out_data(test_struct.ref_out_data.size()); + selector.block_sum_attention_scores(test_struct.in_data.data(), test_struct.in_shape, test_out_data.data(), test_struct.out_shape); + + EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-5), test_struct.ref_out_data)); +} + +INSTANTIATE_TEST_SUITE_P(VariousInputs, + XAttentionBlockSumTest, + ::testing::ValuesIn(BLOCK_SUM_TEST_CASES)); + +struct BlockSelectTestData { + ov::Shape in_shape; + std::vector in_data; + double threshold; + ov::reference::XAttentionRetainedBlockIndicesForAllHeads ref_retained_block_indices; +}; + +using XAttentionBlockSelectTest = ::testing::TestWithParam; + +std::vector BLOCK_SELECT_TEST_CASES = { + { + {2, 2, 4}, + // clang-format off + { + 0.5151, 0.4323, 0.5014, 0.5513, + 0.4557, 0.4876, 0.5870, 0.4697, + + 0.5118, 0.4507, 0.5180, 0.5194, + 0.5315, 0.4446, 0.4929, 0.5310 + }, + // clang-format on + /* threshold = */ 0.25, + { + {{1, 2}, {0, 3}}, + {{1, 0}, {1, 3}}, + }}, + + {{2, 2, 4}, + // clang-format off + { + 0.5151, 0.4323, 0.5014, 0.5513, + 0.4557, 0.4876, 0.5870, 0.4697, + + 0.5118, 0.4507, 0.5180, 0.5194, + 0.5315, 0.4446, 0.4929, 0.5310 + }, + // clang-format on + /* threshold = */ 0.35, + { + {{1, 2}, {0, 3}, {0, 0}}, + {{1, 0}, {1, 3}, {0, 3}}, + }}, + {{2, 2, 4}, + // clang-format off + { + 0.5151, 0.4323, 0.5014, 0.5513, + 0.4557, 0.4876, 0.5870, 0.4697, + + 0.5118, 0.4507, 0.5180, 0.5194, + 0.5315, 0.4446, 0.4929, 0.5310 + }, + // clang-format on + /* threshold = */ 0.1, + { + {{1, 2}}, + {{1, 0}}, + }}, + {{2, 2, 4}, + // clang-format off + { + 0.5151, 0.4323, 0.5014, 0.5513, + 0.4557, 0.4876, 0.5870, 0.4697, + + 0.5118, 0.4507, 0.5180, 0.5194, + 0.5315, 0.4446, 0.4929, 0.5310 + }, + // clang-format on + /* threshold = */ 0.0, + { + {}, + {}, + }}, + {{2, 2, 4}, + // clang-format off + { + 0.5151, 0.4323, 0.5014, 0.5513, + 0.4557, 0.4876, 0.5870, 0.4697, + + 0.5118, 0.4507, 0.5180, 0.5194, + 0.5315, 0.4446, 0.4929, 0.5310 + }, + // clang-format on + /* threshold = */ 1.0, + { + {{1, 2}, {0, 3}, {0, 0}, {0, 2}, {1, 1}, {1, 3}, {1, 0}, {0, 1}}, + {{1, 0}, {1, 3}, {0, 3}, {0, 2}, {0, 0}, {1, 2}, {0, 1}, {1, 1}}, + }}, +}; + +TEST_P(XAttentionBlockSelectTest, BlockSelectionIsCorrect) { + auto test_struct = GetParam(); + ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape)); + + ov::reference::XAttentionBlockSelector selector(test_struct.threshold, DEFAULT_BLOCK_SIZE, DEFAULT_STRIDE); + auto test_result = selector.get_block_indices_to_keep(test_struct.in_data.data(), test_struct.in_shape); + + EXPECT_EQ(test_result, test_struct.ref_retained_block_indices); +} + +INSTANTIATE_TEST_SUITE_P(VariousInputs, XAttentionBlockSelectTest, ::testing::ValuesIn(BLOCK_SELECT_TEST_CASES)); \ No newline at end of file diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 0657488e706f4e..160efddcccfdec 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1,9 +1,10 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include "test_utils.h" #include "random_generator.hpp" +#include "paged_attention_gpu_test.hpp" #include #include @@ -19,643 +20,6 @@ using namespace cldnn; using namespace ov::intel_gpu; using namespace ::tests; -/* -* PagedAttention inputs: -* [0]: query -* shape: [batch_size_in_tokens, num_heads * head_size], type: f16 -* [1]: key -* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 -* [2]: value  -* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 -* [3]: key_cache -* shape: [num_blocks, num_kv_heads, head_size, block_size], type: f16 -* [4]: value_cache -* shape: [num_blocks, num_kv_heads, block_size, head_size], type: f16 -* [5]: past_lens -* shape: [batch_size_in_sequences], type: i32 -* [6]: subsequence_begins -* shape: [batch_size_in_sequences + 1], type: i32 -* [7]: block_indices -* Shape: [num_blocks], type: i32 -* [8]: block_indices_begins -* Shape: [batch_size_in_sequences + 1], type: i32 -* [9]: scale, optional -* [10]: sliding_window, optional -* [11]: alibi_slopes, optional -* [12]: max_context_len -* shape: [], type: i32 -* [13]: score_aggregation_window​, optional​, shape: [batch_size_in_sequences] -* [14]: rotated_block_indices​, optional​ -* shape: [num_rotated_blocks]​, type: i32 -* [15]: rotation_deltas​, optional​ -* shape: [num_rotated_blocks, BLOCK_SIZE]​ || [num_rotated_blocks, 1]​, type: i32 -* [16]: rotation_trig_lut​, optional​ -* shape: [max_num_batched_tokens / BLOCK_SIZE, head_size]​ || [max_num_batched_tokens, head_size], type: f16 -*/ - - -enum class ScoresMode { - DISABLED = 0, - LAST_TOKEN, - SNAPKV -}; - -struct SubsequenceDescriptor { - int num_tokens; - int past_len; -}; - -struct CacheRotationDescriptor { - bool apply_rotation; - // configures 2nd dimension of rotation_deltas - // if per_block is true, single value is used for all tokens inside the block - // otherwise, each token uses an independent value - bool per_block; -}; - -struct PagedAttentionManager { - int num_heads; - int k_head_size; - int v_head_size; - int block_size; - int sliding_window_size; - bool kv_cache_compression; - ov::internal::CacheQuantMode key_cache_quant_mode; - bool has_score_aggregation; - CacheRotationDescriptor rotation_config; - std::vector subsequence_descs; - - // per-subsequence QKV inputs - std::vector> query_data; // {[1, num_tokens, num_heads, k_head_size], ..} - std::vector> key_data; // {[1, past_len + num_tokens, num_heads, k_head_size], ..} - std::vector> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..} - - // common PA inputs - std::vector past_lens; - std::vector subsequence_begins; - std::vector block_indices; - std::vector block_indices_begins; - std::vector max_context_len; - std::vector score_aggregation_window; - - // score aggregation related inputs - std::vector score_aggregation; - - // rotation related inputs - std::vector rotated_block_indices; - std::vector rotation_deltas; - std::vector rotation_trig_lut; - - std::vector xattention_threshold; - std::vector xattention_block_size; - std::vector xattention_stride; - - cldnn::engine& test_engine; - cldnn::stream& test_stream; - tests::random_generator& rg; - - PagedAttentionManager(tests::random_generator& rg, - cldnn::engine& engine, - cldnn::stream& stream, - const std::vector& subsequence_descs, - int num_heads, - int k_head_size, - int v_head_size, - int block_size, - int sliding_window_size, - bool kv_cache_compression, - ov::internal::CacheQuantMode key_cache_quant_mode, - bool has_score_aggregation, - CacheRotationDescriptor rotation_config) - : num_heads(num_heads) - , k_head_size(k_head_size) - , v_head_size(v_head_size) - , block_size(block_size) - , sliding_window_size(sliding_window_size) - , kv_cache_compression(kv_cache_compression) - , key_cache_quant_mode(key_cache_quant_mode) - , has_score_aggregation(has_score_aggregation) - , rotation_config(rotation_config) - , subsequence_descs(subsequence_descs) - , test_engine(engine) - , test_stream(stream) - , rg(rg) { - // init subsequence_begins and block_indices_begins - subsequence_begins.push_back(0); - block_indices_begins.push_back(0); - - int max_len = 0; - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - const auto& subsequence_desc = subsequence_descs[i]; - max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); - - query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); - key_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); - value_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); - - past_lens.push_back(subsequence_desc.past_len); - int subsequence_start_pos = subsequence_begins[i]; - int subsequence_end_pos = subsequence_start_pos + subsequence_desc.num_tokens; - subsequence_begins.push_back(subsequence_end_pos); - - int subsequence_length = subsequence_desc.num_tokens + subsequence_desc.past_len; - int required_blocks = ceil_div(subsequence_length, block_size); - int start_block_idx = block_indices.empty() ? 0 : block_indices.back() + 1; - int end_block_idx = start_block_idx + required_blocks; - for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { - block_indices.push_back(block_idx); - } - - int block_indices_start_pos = block_indices_begins[i]; - int block_indices_end_pos = block_indices_start_pos + required_blocks; - block_indices_begins.push_back(block_indices_end_pos); - } - max_context_len.push_back(max_len); - - if (rotation_config.apply_rotation) { - // iterate over KV-cache blocks and apply cache rotation to every second - // fully occupied block - for (size_t i = 0; i < subsequence_descs.size(); i++) { - const auto& subsequence_desc = subsequence_descs[i]; - int past_len = subsequence_desc.past_len; - int start_block_idx = block_indices_begins[i]; - for (int block_idx = 1; block_idx < past_len / block_size; block_idx++) { - if (block_idx % 2 != 0) { - rotated_block_indices.push_back(start_block_idx + block_idx); - } - } - } - - if (!rotated_block_indices.empty()) { - rotation_deltas = generate_rotation_deltas_data(rg, - max_context_len[0], - rotated_block_indices.size(), - block_size, - rotation_config.per_block); - rotation_trig_lut = generate_rotation_trig_lut_data(rg, max_context_len[0], k_head_size); - } - } - - if (has_score_aggregation) { - for (const auto& subsequence_desc : subsequence_descs) { - const auto max_tokens = 10; - auto max_window_size = std::min(subsequence_desc.num_tokens, max_tokens); - auto window_size = rg.generate_random_val(1, max_window_size); - score_aggregation.push_back(window_size); - } - } - } - - memory::ptr get_query_memory() { - return get_QKV_memory(query_data, k_head_size, false); - } - - memory::ptr get_key_memory() { - return get_QKV_memory(key_data, k_head_size, true); - } - - memory::ptr get_value_memory() { - return get_QKV_memory(value_data, v_head_size, true); - } - -#if ENABLE_PA_CM_PATH - memory::ptr get_key_cache_memory() { - auto key_cache_dt = data_types::f16; - auto adjusted_head_size = k_head_size; - if (kv_cache_compression) { - key_cache_dt = data_types::i8; - adjusted_head_size += 4; - } - - auto num_blocks = block_indices.back() + 1; - auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; - auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; - auto memory = test_engine.allocate_memory(key_cache_layout); - - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - int past_len = subsequence_descs[i].past_len; - if (past_len != 0) { - int blocks_num = ceil_div(past_len + 1, block_size); - int start_block_idx = block_indices[block_indices_begins[i]]; - for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size - : block_size; - for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + - input_token_offset * num_heads * v_head_size + - head_idx * v_head_size; - if (kv_cache_compression) { - auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); - auto quantized_data_ptr = quantized_data.data(); - - // shape: [num_blocks, num_heads, block_size, adjusted_head_size] - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + - head_idx * block_size * adjusted_head_size; - size_t output_offset = output_block_offset + - token_idx * v_head_size; - set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); - - size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; - set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); - set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); - } else { - // shape: [num_blocks, num_heads, block_size, v_head_size] - size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + - head_idx * block_size * v_head_size + - token_idx * v_head_size; - - set_values(test_stream, memory, data_ptr, v_head_size, output_offset); - } - } - } - } - } - } - - return memory; - } - -#else - memory::ptr get_key_cache_memory() { - auto key_cache_dt = data_types::f16; - auto adjusted_head_size = k_head_size; - auto adjusted_block_size = block_size; - if (kv_cache_compression) { - key_cache_dt = data_types::i8; - if (key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { - adjusted_block_size += 4; - } else { - adjusted_head_size += 4; - } - } - - auto num_blocks = block_indices.back() + 1; - auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, adjusted_head_size, adjusted_block_size }; - auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; - auto memory = test_engine.allocate_memory(key_cache_layout); - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - int past_len = subsequence_descs[i].past_len; - if (past_len != 0) { - int blocks_num = ceil_div(past_len + 1, block_size); - int start_block_idx = block_indices[block_indices_begins[i]]; - for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size - : block_size; - // quantize by channel - if (kv_cache_compression && key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { - std::vector token_block(block_size); - for (int token_idx = 0; token_idx < last_token_idx; ++token_idx) { - size_t input_token_offset = block_idx * block_size + token_idx; - token_block[token_idx] = *(key_data[i].data() + input_token_offset * num_heads * k_head_size + head_idx * k_head_size + k_head_size_idx); - } - auto [quantized_data, scale, zp] = quantize_data(token_block.data(), last_token_idx, true); - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * adjusted_head_size * adjusted_block_size + - head_idx * adjusted_head_size * adjusted_block_size; - size_t output_offset = output_block_offset + - k_head_size_idx * adjusted_block_size; - set_values(test_stream, memory, quantized_data.data(), last_token_idx, output_offset); - size_t comp_offset = (output_offset + block_size)/2; - set_values(test_stream, memory, &scale, 1, comp_offset); - set_values(test_stream, memory, &zp, 1, comp_offset + 1); - } - } - } - for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - if (kv_cache_compression) { - if (key_cache_quant_mode == ov::internal::CacheQuantMode::BY_TOKEN) { - // quantize by token - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + - input_token_offset * num_heads * k_head_size + - head_idx * k_head_size; - // shape: [num_blocks, num_heads, adjusted_head_size, block_size] - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * adjusted_head_size * block_size + - head_idx * adjusted_head_size * block_size; - - auto [quantized_data, scale, zp] = quantize_data(data_ptr, k_head_size); - for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { - auto quantized_data_ptr = quantized_data.data() + k_head_size_idx; - - size_t output_offset = output_block_offset + - k_head_size_idx * block_size + - token_idx; - - set_values(test_stream, memory, quantized_data_ptr, 1, output_offset); - } - size_t comp_offset = (output_block_offset + k_head_size * block_size) / 2; - set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); - set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); - } - } else { - for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + - input_token_offset * num_heads * k_head_size + - head_idx * k_head_size + k_head_size_idx; - - // shape: [num_blocks, num_heads, k_head_size, block_size] - size_t output_offset = (start_block_idx + block_idx) * num_heads * k_head_size * block_size + - head_idx * k_head_size * block_size + - k_head_size_idx * block_size + - token_idx; - - set_values(test_stream, memory, data_ptr, 1, output_offset); - } - } - } - } - } - } - } - - return memory; - } -#endif - - memory::ptr get_value_cache_memory() { - auto value_cache_dt = data_types::f16; - auto adjusted_head_size = v_head_size; - if (kv_cache_compression) { - value_cache_dt = data_types::i8; - adjusted_head_size += 4; - } - - auto num_blocks = block_indices.back() + 1; - auto value_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; - auto value_cache_layout = layout{ value_cache_shape, value_cache_dt, format::bfyx }; - auto memory = test_engine.allocate_memory(value_cache_layout); - - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - int past_len = subsequence_descs[i].past_len; - if (past_len != 0) { - int blocks_num = ceil_div(past_len + 1, block_size); - int start_block_idx = block_indices[block_indices_begins[i]]; - for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size - : block_size; - for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = value_data[i].data() + - input_token_offset * num_heads * v_head_size + - head_idx * v_head_size; - if (kv_cache_compression) { - auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); - auto quantized_data_ptr = quantized_data.data(); - - // shape: [num_blocks, num_heads, block_size, adjusted_head_size] - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + - head_idx * block_size * adjusted_head_size; - size_t output_offset = output_block_offset + - token_idx * v_head_size; - set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); - - size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; - set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); - set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); - } else { - // shape: [num_blocks, num_heads, block_size, v_head_size] - size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + - head_idx * block_size * v_head_size + - token_idx * v_head_size; - - set_values(test_stream, memory, data_ptr, v_head_size, output_offset); - } - } - } - } - } - } - - return memory; - } - - memory::ptr get_past_lens_memory() { - return get_memory_from_vec(past_lens); - } - - memory::ptr get_subsequence_begins_memory() { - return get_memory_from_vec(subsequence_begins); - } - - memory::ptr get_block_indices_memory() { - return get_memory_from_vec(block_indices); - } - - memory::ptr get_block_indices_begins_memory() { - return get_memory_from_vec(block_indices_begins); - } - - memory::ptr get_scale_memory() { - std::vector scale = { ov::float16(get_default_scale()) }; - return get_memory_from_vec(scale); - } - - memory::ptr get_sliding_window_memory() { - std::vector sliding_window = { 0 }; - return get_memory_from_vec(sliding_window); - } - - memory::ptr get_alibi_memory() { - std::vector alibi; - return get_memory_from_vec(alibi); - } - - memory::ptr get_max_context_len_memory() { - return get_memory_from_vec(max_context_len); - } - - memory::ptr get_score_aggregation() { - return get_memory_from_vec(score_aggregation); - } - - memory::ptr get_rotated_block_indices_memory() { - return get_memory_from_vec(rotated_block_indices); - } - - memory::ptr get_rotation_deltas_memory() { - auto mem = get_memory_from_vec(rotation_deltas); - auto layout = mem->get_layout(); - auto last_dim = rotation_config.per_block ? 1 : block_size; - layout.set_partial_shape(ov::PartialShape{ static_cast(rotated_block_indices.size()), last_dim }); - - return test_engine.reinterpret_buffer(*mem, layout); - } - - memory::ptr get_rotation_trig_lut_memory() { - auto mem = get_memory_from_vec(rotation_trig_lut); - auto layout = mem->get_layout(); - layout.set_partial_shape(ov::PartialShape{ max_context_len[0], k_head_size }); - - if (rotated_block_indices.empty()) { - auto empty_layout = mem->get_layout(); - empty_layout.set_partial_shape(ov::PartialShape{ 0, k_head_size }); - return test_engine.reinterpret_buffer(*mem, empty_layout); - } - - return test_engine.reinterpret_buffer(*mem, layout); - } - - memory::ptr get_xattention_threshold_memory() { - auto mem = get_memory_from_vec(xattention_threshold); - auto layout = mem->get_layout(); - layout.set_partial_shape(ov::PartialShape{ 1 }); - - if (xattention_threshold.empty()) { - auto empty_layout = mem->get_layout(); - empty_layout.set_partial_shape(ov::PartialShape{ 0 }); - return test_engine.reinterpret_buffer(*mem, empty_layout); - } - - return test_engine.reinterpret_buffer(*mem, layout); - } - - memory::ptr get_xattention_block_size_memory() { - return get_memory_from_vec(xattention_block_size); - } - - memory::ptr get_xattention_stride_memory() { - return get_memory_from_vec(xattention_stride); - } - - float get_default_scale() { - return static_cast(1.f / std::sqrt(k_head_size)); - } - -private: - template - memory::ptr get_memory_from_vec(std::vector& input_data) { - auto data_size = input_data.empty() ? 1 : input_data.size(); - auto shape = ov::PartialShape{ static_cast(data_size) }; - auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; - auto memory = test_engine.allocate_memory(layout); - - if (input_data.empty()) { - auto shape = ov::PartialShape{0}; - auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; - return test_engine.reinterpret_buffer(*memory, layout); - } - - set_values(test_stream, memory, input_data.data(), input_data.size(), 0); - - return memory; - } - - memory::ptr get_QKV_memory(std::vector>& input_data, int k_head_size, bool skip_past_len) { - int total_tokens = 0; - for (const auto& subsequence_desc : subsequence_descs) - total_tokens += subsequence_desc.num_tokens; - - auto query_shape = ov::PartialShape{ total_tokens, num_heads * k_head_size }; - auto query_layout = layout{ query_shape, data_types::f16, format::bfyx }; - auto memory = test_engine.allocate_memory(query_layout); - - for (int subsequence_idx = 0; subsequence_idx < static_cast(subsequence_descs.size()); subsequence_idx++) { - for (int token_idx = 0; token_idx < subsequence_descs[subsequence_idx].num_tokens; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - size_t input_token_offset = token_idx; - // as generated data stored in vectors includes past_len, ignore it for KV inputs - if (skip_past_len) - input_token_offset += subsequence_descs[subsequence_idx].past_len; - - ov::float16* data_ptr = input_data[subsequence_idx].data() + - input_token_offset * num_heads * k_head_size + - head_idx * k_head_size; - - size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx; - size_t output_offset = output_token_offset * num_heads * k_head_size + - head_idx * k_head_size; - - set_values(test_stream, memory, data_ptr, k_head_size, output_offset); - } - } - } - - return memory; - } - - template - static void set_values(stream& stream, memory::ptr mem, T* vals, size_t size, size_t dst_offset) { - mem_lock mem_ptr(mem, stream); - for (size_t i = 0; i < size; i++) { - mem_ptr[dst_offset + i] = vals[i]; - } - } - - static std::vector generate_input_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t k_head_size) { - const size_t total_elements_num = tokens_num * num_heads * k_head_size; - auto data = rg.generate_random_1d(total_elements_num, -1, 1); - - // test code - // auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 10000); - - return data; - } - - static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { - const size_t total_elements_num = per_block ? rotated_blocks_num - : rotated_blocks_num * block_size; - auto data = rg.generate_random_1d(total_elements_num, 0, static_cast(max_tokens_num - 1)); - - return data; - } - - static std::vector generate_rotation_trig_lut_data(tests::random_generator& rg, size_t max_tokens_num, size_t k_head_size) { - const size_t total_elements_num = max_tokens_num * k_head_size; - auto data = rg.generate_random_1d(total_elements_num, -1, 1); - - return data; - } - - static std::tuple, ov::float16, ov::float16> quantize_data(ov::float16* data, size_t size, bool expand_range = false) { - float min_value = std::numeric_limits::max(); - float max_value = std::numeric_limits::lowest(); - - for (size_t i = 0; i < size; i++) { - min_value = std::min((float)(data[i]), min_value); - max_value = std::max((float)(data[i]), max_value); - } - - float diff_value = 0.001; - if (max_value != min_value) - diff_value = max_value - min_value; - if (expand_range && std::abs(diff_value) <= std::abs(max_value) * 0.1f) { - // compensate too small range - diff_value = (max_value - min_value) + std::max(1.0f, max_value * 0.1f); - } - float scale = (std::numeric_limits::max() - std::numeric_limits::lowest()) / diff_value; - float zp = ((float)-min_value * scale) + std::numeric_limits::lowest(); - - std::vector quantized_data; - quantized_data.resize(size); - - auto convert_char_rte = [](float val) { - float rounded = std::nearbyint(val); - - if (rounded > 127.0f) { - return static_cast(127); - } else if (rounded < -128.0f) { - return static_cast(-128); - } else { - return static_cast(rounded); - } - }; - - for (size_t i = 0; i < size; i++) { - quantized_data[i] = convert_char_rte(data[i] * scale + zp); - } - - scale = 1.0f / scale; - - return std::make_tuple(quantized_data, scale, zp); - } -}; - namespace std { template <> struct hash { diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp new file mode 100644 index 00000000000000..067d7817a4a13e --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -0,0 +1,871 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paged_attention_gpu_test.hpp" +#include "random_generator.hpp" +#include "test_utils.h" + +using namespace cldnn; +using namespace ov::intel_gpu; +using namespace ::tests; + +namespace std { +template <> +struct hash { + uint64_t operator()(const ov::float16 __val) const { + return std::hash()(__val); + } +}; +} // namespace std + +struct xAttentionReference { + xAttentionReference(PagedAttentionManager& pam) : pam(pam), test_engine(pam.test_engine), test_stream(pam.test_stream) {} + + std::pair, std::vector> get_reference() { + std::vector ref_data_output; + std::vector ref_scores_output; + + for (size_t i = 0; i < pam.subsequence_descs.size(); i++) { + const auto& subsequence_desc = pam.subsequence_descs[i]; + const auto kv_seq_len = subsequence_desc.num_tokens + subsequence_desc.past_len; + + auto key_data = pam.key_data[i]; + if (pam.rotation_config.apply_rotation) { + auto blocks_start = pam.block_indices_begins[i]; + auto blocks_end = pam.block_indices_begins[i + 1]; + + std::vector block_indices(pam.block_indices.begin() + blocks_start, pam.block_indices.begin() + blocks_end); + + for (const auto& block_idx : block_indices) { + auto it = std::find(pam.rotated_block_indices.begin(), pam.rotated_block_indices.end(), block_idx); + if (it != pam.rotated_block_indices.end()) { + int index = std::distance(pam.rotated_block_indices.begin(), it); + int subsequence_rotated_block_idx = *it - blocks_start; + + rotate_block(key_data, + pam.rotation_deltas, + pam.rotation_trig_lut, + index, + subsequence_rotated_block_idx, + pam.num_heads, + pam.k_head_size, + pam.block_size, + pam.rotation_config.per_block); + } + } + } + + auto window_size = pam.has_score_aggregation ? pam.score_aggregation[i] : 1; + + auto subsequence_ref_results = run_reference(pam.query_data[i], + key_data, + pam.value_data[i], + subsequence_desc.num_tokens, + kv_seq_len, + pam.num_heads, + pam.k_head_size, + pam.v_head_size, + window_size, + pam.sliding_window_size, + pam.get_default_scale()); + + // concatenate all subsequences into one vector + ref_data_output.insert(ref_data_output.end(), subsequence_ref_results.first.begin(), subsequence_ref_results.first.end()); + ref_scores_output.insert(ref_scores_output.end(), subsequence_ref_results.second.begin(), subsequence_ref_results.second.end()); + } + + return {ref_data_output, ref_scores_output}; + } + +private: + // void print_tensor(const std::vector& data, size_t heads, size_t rows, size_t cols, const std::string& name) { + // std::cout << name << " (" << heads << "x" << rows << "x" << cols << "):\n"; + // for (size_t h = 0; h < heads; h++) { + // std::cout << " Head " << h << ":\n"; + // for (size_t i = 0; i < rows; i++) { + // for (size_t j = 0; j < cols; j++) { + // std::cout << static_cast(data[h * rows * cols + i * cols + j]) << " "; + // } + // std::cout << "\n"; + // } + // } + // } + + std::vector softmax_1(const std::vector& logits) { + std::vector out(logits.size()); + float max_val = *std::max_element(logits.begin(), logits.end()); + float sum = 0.0f; + for (float v : logits) + sum += std::exp(v - max_val); + for (size_t i = 0; i < logits.size(); i++) { + out[i] = static_cast(std::exp(logits[i] - max_val) / sum); + } + return out; + } + + std::vector safe_softmax(const std::vector& logits) { + std::vector probs(logits.size(), 0.0f); + float max_logit = -std::numeric_limits::infinity(); + for (float l : logits) + max_logit = std::max(max_logit, l); + if (std::isinf(max_logit)) + return probs; + + float sum_exp = 0.0f; + for (float l : logits) + sum_exp += std::exp(l - max_logit); + if (sum_exp == 0.0f) + return probs; + + for (size_t i = 0; i < logits.size(); ++i) + probs[i] = std::exp(logits[i] - max_logit) / sum_exp; + return probs; + } + + std::vector compute_sparse_causal_attention(const std::vector& Q_in, // [B, Tq, H, Dq] + const std::vector& K_in, // [B, Tk, H, Dk] + const std::vector& V_in, // [B, Tk, H, Dv] + size_t num_heads, + size_t num_queries, + size_t num_keys, + size_t qk_head_dim, + size_t v_head_dim, + const ov::reference::XAttentionRetainedBlockIndicesForAllHeads& retained_blocks_for_all_heads = {}, + float scale = 0.0f, + size_t block_size = 1) { + if (scale == 0.0f) + scale = 1.0f / std::sqrt(static_cast(qk_head_dim)); + + bool use_sparse = !retained_blocks_for_all_heads.empty(); + std::vector output(num_heads * num_queries * v_head_dim, ov::float16(0.0f)); + + std::cout << "---- compute_sparse_causal_attention ----\n"; + std::cout << "num_heads=" << num_heads << " num_queries=" << num_queries << " num_keys=" << num_keys << " qk_head_dim=" << qk_head_dim + << " v_head_dim=" << v_head_dim << " scale=" << scale << "\n"; + + // ======== permute Q,K,V from [B,T,H,D] → [H,T,D] ======== + std::vector Q(num_heads * num_queries * qk_head_dim); + std::vector K(num_heads * num_keys * qk_head_dim); + std::vector V(num_heads * num_keys * v_head_dim); + + for (size_t h = 0; h < num_heads; ++h) { + for (size_t t = 0; t < num_queries; ++t) { + for (size_t d = 0; d < qk_head_dim; ++d) { + Q[h * num_queries * qk_head_dim + t * qk_head_dim + d] = Q_in[t * num_heads * qk_head_dim + h * qk_head_dim + d]; + } + } + for (size_t t = 0; t < num_keys; ++t) { + for (size_t d = 0; d < qk_head_dim; ++d) { + K[h * num_keys * qk_head_dim + t * qk_head_dim + d] = K_in[t * num_heads * qk_head_dim + h * qk_head_dim + d]; + } + for (size_t d = 0; d < v_head_dim; ++d) { + V[h * num_keys * v_head_dim + t * v_head_dim + d] = V_in[t * num_heads * v_head_dim + h * v_head_dim + d]; + } + } + } + + // ======== Attention per head ======== + for (size_t h = 0; h < num_heads; ++h) { + const auto& retained_blocks = use_sparse ? retained_blocks_for_all_heads[h] : ov::reference::XAttentionRetainedBlockIndices{}; + + if (use_sparse) { + std::cout << "Head " << h << " retained blocks: "; + for (const auto& blk : retained_blocks) + std::cout << "(" << blk.first << "," << blk.second << ") "; + std::cout << std::endl; + } + + for (size_t q = 0; q < num_queries; ++q) { + std::vector logits(num_keys, -1e9f); + bool any_valid = false; + + for (size_t k = 0; k < num_keys; ++k) { + size_t q_block = q / block_size; + size_t k_block = k / block_size; + + if (use_sparse && retained_blocks.find({q_block, k_block}) == retained_blocks.end()) + continue; + if (k > q) + continue; // causal mask + + float score = 0.0f; + for (size_t d = 0; d < qk_head_dim; ++d) + score += static_cast(Q[h * num_queries * qk_head_dim + q * qk_head_dim + d]) * + static_cast(K[h * num_keys * qk_head_dim + k * qk_head_dim + d]); + logits[k] = score * scale; + any_valid = true; + } + + if (!any_valid) { + std::cout << "Head " << h << ", Query " << q << " has no valid keys -> zero output.\n"; + continue; + } + + auto probs = safe_softmax(logits); + + for (size_t d = 0; d < v_head_dim; ++d) { + float acc = 0.0f; + for (size_t k = 0; k <= q; ++k) { + if (use_sparse && retained_blocks.find({q / block_size, k / block_size}) == retained_blocks.end()) + continue; + acc += probs[k] * static_cast(V[h * num_keys * v_head_dim + k * v_head_dim + d]); + } + output[h * num_queries * v_head_dim + q * v_head_dim + d] = static_cast(acc); + } + } + } + + // ======== Debug summary ======== + std::cout << "Output preview (head0, first few queries):\n"; + for (size_t q = 0; q < std::min(4, num_queries); ++q) { + std::cout << " Q" << q << ": "; + for (size_t d = 0; d < std::min(8, v_head_dim); ++d) + std::cout << static_cast(output[q * v_head_dim + d]) << " "; + std::cout << "\n"; + } + + return output; + } + + std::pair, std::vector> run_reference(const std::vector& query_data, + const std::vector& key_data, + const std::vector& value_data, + int num_queries, + int num_keys, + int num_heads, + int k_head_size, + int v_head_size, + int window_size, + int sliding_window_size, + float scale, + double threshold = 0.8, + size_t block_size = 256, + size_t stride = 16) { + // --- 1. allocate memory --- + auto query_shape_bfyx = ov::PartialShape{1, num_queries, num_heads, k_head_size}; + auto key_shape_bfyx = ov::PartialShape{1, num_keys, num_heads, k_head_size}; + auto value_shape_bfyx = ov::PartialShape{1, num_keys, num_heads, v_head_size}; + + auto query_layout = layout{query_shape_bfyx, data_types::f16, format::bfyx}; + auto key_layout = layout{key_shape_bfyx, data_types::f16, format::bfyx}; + auto value_layout = layout{value_shape_bfyx, data_types::f16, format::bfyx}; + + OPENVINO_ASSERT(query_layout.count() == query_data.size()); + OPENVINO_ASSERT(key_layout.count() == key_data.size()); + OPENVINO_ASSERT(value_layout.count() == value_data.size()); + + auto query_mem = test_engine.allocate_memory(query_layout); + auto key_mem = test_engine.allocate_memory(key_layout); + auto value_mem = test_engine.allocate_memory(value_layout); + + set_values(query_mem, query_data); + set_values(key_mem, key_data); + set_values(value_mem, value_data); + + // std::cout << "=== query_data (bfyx layout) ===" << std::endl; + // for (int q = 0; q < num_queries; q++) { + // for (int h = 0; h < num_heads; h++) { + // std::cout << "q=" << q << ", h=" << h << ": ["; + // for (int d = 0; d < k_head_size; d++) { + // auto val = query_data[q * num_heads * k_head_size + h * k_head_size + d]; + // std::cout << static_cast(val) << (d + 1 < k_head_size ? ", " : ""); + // } + // std::cout << "]" << std::endl; + // } + // } + + // std::cout << "=== key_data (bfyx layout) ===" << std::endl; + // for (int k = 0; k < num_keys; k++) { + // for (int h = 0; h < num_heads; h++) { + // std::cout << "k=" << k << ", h=" << h << ": ["; + // for (int d = 0; d < k_head_size; d++) { + // auto val = key_data[k * num_heads * k_head_size + h * k_head_size + d]; + // std::cout << static_cast(val) << (d + 1 < k_head_size ? ", " : ""); + // } + // std::cout << "]" << std::endl; + // } + // } + + std::vector query_data_3d(num_heads * num_queries * k_head_size); + std::vector key_data_3d(num_heads * num_keys * k_head_size); + + for (int h = 0; h < num_heads; h++) { + for (int q = 0; q < num_queries; q++) { + for (int d = 0; d < k_head_size; d++) { + query_data_3d[h * num_queries * k_head_size + q * k_head_size + d] = query_data[q * num_heads * k_head_size + h * k_head_size + d]; + } + } + } + + for (int h = 0; h < num_heads; h++) { + for (int k = 0; k < num_keys; k++) { + for (int d = 0; d < k_head_size; d++) { + key_data_3d[h * num_keys * k_head_size + k * k_head_size + d] = key_data[k * num_heads * k_head_size + h * k_head_size + d]; + } + } + } + + ov::Shape query_shape_3d = {static_cast(num_heads), static_cast(num_queries), static_cast(k_head_size)}; + ov::Shape key_shape_3d = {static_cast(num_heads), static_cast(num_keys), static_cast(k_head_size)}; + + ov::reference::XAttentionRetainedBlockIndicesForAllHeads retained_blocks; + { + ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); + retained_blocks = selector.select_blocks(query_data_3d.data(), query_shape_3d, key_data_3d.data(), key_shape_3d); + + // std::cout << "=== C++ 选中 blocks ===" << std::endl; + // for (size_t h = 0; h < retained_blocks.size(); ++h) { + // std::cout << "Head " << h << " selected blocks: "; + // for (const auto& idx_pair : retained_blocks[h]) { + // std::cout << "(" << idx_pair.first << "," << idx_pair.second << ") "; + // } + // std::cout << std::endl; + // } + } + + // auto output = compute_sparse_causal_attention(query_data, + // key_data, + // value_data, + // num_heads, + // num_queries, + // num_keys, + // k_head_size, + // v_head_size, + // retained_blocks, + // 0.0f, + // block_size); + + // print_tensor(output, num_heads, num_queries, k_head_size, "Output"); + auto mask_mem = get_mask_mem_combined_multi_head(num_queries, num_keys, num_heads, sliding_window_size, retained_blocks, block_size); + + topology topology; + topology.add(input_layout("query", query_layout), + input_layout("key", key_layout), + input_layout("value", value_layout), + data("mask", mask_mem), + permute("query_transposed", input_info("query"), {0, 2, 1, 3}), + permute("key_transposed", input_info("key"), {0, 2, 1, 3}), + permute("value_transposed", input_info("value"), {0, 2, 1, 3}), + gemm("qk_gemm", {input_info("query_transposed"), input_info("key_transposed")}, data_types::f16, false, true, scale), + eltwise("eltwise", {input_info("qk_gemm"), input_info("mask")}, eltwise_mode::sum), + softmax("softmax", input_info("eltwise"), -1), + gemm("qkv_gemm", {input_info("softmax"), input_info("value_transposed")}, data_types::f16, false, false), + permute("qkv_gemm_transposed", input_info("qkv_gemm"), {0, 2, 1, 3}), + reorder("output_data", input_info("qkv_gemm_transposed"), format::bfyx, data_types::f16), + reorder("scores_data", input_info("softmax"), format::bfyx, data_types::f16)); + + ExecutionConfig config = get_test_default_config(test_engine); + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + + network::ptr network = get_network(test_engine, topology, config, get_test_stream_ptr(), false); + network->set_input_data("query", query_mem); + network->set_input_data("key", key_mem); + network->set_input_data("value", value_mem); + + auto outputs = network->execute(); + + auto output_data_mem = outputs.at("output_data").get_memory(); + auto output_scores_mem = outputs.at("scores_data").get_memory(); + + return {get_output_data_vec(output_data_mem, num_queries, v_head_size, num_heads), + get_output_scores_vec(output_scores_mem, window_size, num_queries, num_keys, num_heads)}; + } + + std::vector get_output_scores_vec(memory::ptr scores_output, int window_size, int num_queries, int num_keys, int num_heads) { + OPENVINO_ASSERT(scores_output->count() == static_cast(num_heads * num_queries * num_keys)); + + std::vector output_scores(num_keys, 0); + mem_lock mem_ptr(scores_output, test_stream); + for (int row_idx = 0; row_idx < window_size; row_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + for (int score_idx = 0; score_idx < num_keys; score_idx++) { + auto scores_offset = head_idx * num_queries * num_keys + (num_queries - window_size + row_idx) * num_keys + score_idx; + output_scores[score_idx] += mem_ptr[scores_offset]; + } + } + } + + return output_scores; + } + + std::vector get_output_data_vec(memory::ptr data_output, int num_queries, int k_head_size, int num_heads) { + OPENVINO_ASSERT(data_output->count() == static_cast(num_queries * num_heads * k_head_size)); + + std::vector output_data(data_output->count()); + mem_lock mem_ptr(data_output, test_stream); + for (size_t i = 0; i < data_output->count(); i++) + output_data[i] = mem_ptr[i]; + + return output_data; + } + + memory::ptr get_mask_mem_combined_multi_head(int num_queries, + int num_keys, + int num_heads, + int sliding_window_size, + const ov::reference::XAttentionRetainedBlockIndicesForAllHeads& retained_blocks, + int block_size) { + // mask layout: [1, num_heads, num_queries, num_keys] + auto mask_shape = ov::PartialShape{1, num_heads, num_queries, num_keys}; + auto mask_layout = layout{mask_shape, data_types::f16, format::bfyx}; + auto mask_mem = test_engine.allocate_memory(mask_layout); + + mem_lock mem_ptr(mask_mem, test_stream); + + for (int h = 0; h < num_heads; h++) { + if (retained_blocks.empty() || retained_blocks[h].empty()) { + for (int i = 0; i < num_queries; i++) { + for (int j = 0; j < num_keys; j++) { + ov::float16 value = ov::float16(0.f); + if (sliding_window_size == 0) { + int past_len = num_keys - num_queries + 1; + if (j >= past_len + i) + value = std::numeric_limits::lowest(); + } else { + int sliding_left = num_keys - num_queries - sliding_window_size + 1; + int past_len = num_keys - num_queries + 1; + bool is_min; + if (num_queries == num_keys) { + is_min = (j >= sliding_left + i) && (j <= i) ? 0 : 1; + } else { + is_min = (j >= sliding_left + i) && (j < past_len + i) ? 0 : 1; + } + if (is_min) + value = std::numeric_limits::lowest(); + } + mem_ptr[h * num_queries * num_keys + i * num_keys + j] = value; + } + } + continue; + } + + for (int i = 0; i < num_queries; i++) { + for (int j = 0; j < num_keys; j++) { + mem_ptr[h * num_queries * num_keys + i * num_keys + j] = std::numeric_limits::lowest(); + } + } + + for (int i = 0; i < num_queries; i++) { + int left_idx = 0; + int right_idx = 0; + + if (sliding_window_size == 0) { + int past_len = num_keys - num_queries + 1; + right_idx = past_len + i - 1; + left_idx = 0; + } else { + int sliding_left = num_keys - num_queries - sliding_window_size + 1; + int past_len = num_keys - num_queries + 1; + if (num_queries == num_keys) { + left_idx = sliding_left + i; + right_idx = i; + } else { + left_idx = sliding_left + i; + right_idx = past_len + i - 1; + } + } + + left_idx = std::max(0, left_idx); + right_idx = std::min(num_keys - 1, right_idx); + + for (const auto& [q_block_idx, k_block_idx] : retained_blocks[h]) { + int q_start = q_block_idx * block_size; + int q_end = std::min(q_start + block_size, num_queries); + int k_start = k_block_idx * block_size; + int k_end = std::min(k_start + block_size, num_keys); + + if (i < q_start || i >= q_end) + continue; + + for (int j = k_start; j < k_end; j++) { + if (j >= left_idx && j <= right_idx) { + mem_ptr[h * num_queries * num_keys + i * num_keys + j] = ov::float16(0.f); + } + } + } + } + } + + return mask_mem; + } + + void rotate_block(std::vector& cache_data, + std::vector rotation_deltas, + std::vector rotation_trig_lut_mem, + int rotated_block_idx, + int subsequence_rotated_block_idx, + int num_heads, + int k_head_size, + int block_size, + bool per_block) { + // cache_data shape: [1, num_tokens, num_heads, k_head_size] + int start_token_idx = subsequence_rotated_block_idx * block_size; + + for (int token_idx = 0; token_idx < block_size; token_idx++) { + auto rotation_deltas_offset = per_block ? rotated_block_idx : rotated_block_idx * block_size + token_idx; + auto rotation_trig_lut_idx = rotation_deltas[rotation_deltas_offset]; + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + for (int k_head_size_idx = 0; k_head_size_idx < k_head_size / 2; k_head_size_idx++) { + auto input_offset = (start_token_idx + token_idx) * num_heads * k_head_size + head_idx * k_head_size + k_head_size_idx; + + auto cache_value_0 = cache_data[input_offset]; + auto cache_value_1 = cache_data[input_offset + k_head_size / 2]; + + ov::float16 rotation_value_cos = rotation_trig_lut_mem[rotation_trig_lut_idx * k_head_size + k_head_size_idx]; + ov::float16 rotation_value_sin = rotation_trig_lut_mem[rotation_trig_lut_idx * k_head_size + k_head_size_idx + k_head_size / 2]; + + cache_data[input_offset] = cache_value_0 * rotation_value_cos - cache_value_1 * rotation_value_sin; + cache_data[input_offset + k_head_size / 2] = cache_value_0 * rotation_value_sin + cache_value_1 * rotation_value_cos; + } + } + } + } + + PagedAttentionManager& pam; + cldnn::engine& test_engine; + cldnn::stream& test_stream; +}; + +template +struct xAttentionTest : public ::testing::TestWithParam { +public: + random_generator rg; + cldnn::engine& engine = get_test_engine(); + float tolerance = 2e-3; + + void SetUp() override { + rg.set_seed(GET_SUITE_NAME); + } + + void execute(T& p) { + PagedAttentionManager pam(rg, + get_test_engine(), + get_test_stream(), + p.subsequences, + p.num_heads, + p.k_head_size, + p.v_head_size, + p.block_size, + p.sliding_window_size, + p.kv_cache_compression, + p.key_cache_quant_mode, + p.scores_mode == ScoresMode::SNAPKV, + p.rotation_config); + + if (p.kv_cache_compression) + tolerance = 25e-3; + + auto query_mem = pam.get_query_memory(); + auto key_mem = pam.get_key_memory(); + auto value_mem = pam.get_value_memory(); + + auto key_cache_mem = pam.get_key_cache_memory(); + auto value_cache_mem = pam.get_value_cache_memory(); + + auto past_lens_mem = pam.get_past_lens_memory(); + auto subsequence_begins_mem = pam.get_subsequence_begins_memory(); + auto block_indices_mem = pam.get_block_indices_memory(); + auto block_indices_begins_mem = pam.get_block_indices_begins_memory(); + + auto scale_mem = pam.get_scale_memory(); + auto sliding_window_mem = pam.get_sliding_window_memory(); + auto alibi_mem = pam.get_alibi_memory(); + auto max_context_len_mem = pam.get_max_context_len_memory(); + + // scores calculation related memory buffers + auto score_aggregation_mem = pam.get_score_aggregation(); + + // cache rotation related memory buffers + auto rotated_block_indices_mem = pam.get_rotated_block_indices_memory(); + auto rotation_deltas_mem = pam.get_rotation_deltas_memory(); + auto rotation_trig_lut_mem = pam.get_rotation_trig_lut_memory(); + + auto xattention_threshold_mem = pam.get_xattention_threshold_memory(); + auto xattention_block_size_mem = pam.get_xattention_block_size_memory(); + auto xattention_stride_mem = pam.get_xattention_stride_memory(); + + auto query_layout = query_mem->get_layout(); + auto key_layout = key_mem->get_layout(); + auto value_layout = value_mem->get_layout(); + auto key_cache_layout = key_cache_mem->get_layout(); + auto value_cache_layout = value_cache_mem->get_layout(); + auto past_lens_layout = past_lens_mem->get_layout(); + auto subsequence_begins_layout = subsequence_begins_mem->get_layout(); + auto block_indices_layout = block_indices_mem->get_layout(); + auto block_indices_begins_layout = block_indices_begins_mem->get_layout(); + auto scale_layout = scale_mem->get_layout(); + auto sliding_window_layout = sliding_window_mem->get_layout(); + auto alibi_layout = alibi_mem->get_layout(); + auto max_context_len_layout = max_context_len_mem->get_layout(); + auto score_aggregation_window_layout = score_aggregation_mem->get_layout(); + auto rotated_block_indices_layout = rotated_block_indices_mem->get_layout(); + auto rotation_deltas_layout = rotation_deltas_mem->get_layout(); + auto rotation_trig_lut_layout = rotation_trig_lut_mem->get_layout(); + auto xattention_threshold_layout = xattention_threshold_mem->get_layout(); + auto xattention_block_size_layout = xattention_block_size_mem->get_layout(); + auto xattention_stride_layout = xattention_stride_mem->get_layout(); + + // make layouts dynamic + query_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.k_head_size }); + key_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.k_head_size }); + value_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.v_head_size }); +#if ENABLE_PA_CM_PATH + key_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.k_head_size }); +#else + key_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.k_head_size, p.block_size }); +#endif + value_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.v_head_size }); + past_lens_layout.set_partial_shape(ov::PartialShape{ -1 }); + subsequence_begins_layout.set_partial_shape(ov::PartialShape{ -1 }); + block_indices_layout.set_partial_shape(ov::PartialShape{ -1 }); + block_indices_begins_layout.set_partial_shape(ov::PartialShape{ -1 }); + score_aggregation_window_layout.set_partial_shape(ov::PartialShape{ -1 }); + rotated_block_indices_layout.set_partial_shape(ov::PartialShape{ -1 }); + rotation_deltas_layout.set_partial_shape(ov::PartialShape{ -1, -1 }); + rotation_trig_lut_layout.set_partial_shape(ov::PartialShape{ -1, p.k_head_size }); + xattention_threshold_layout.set_partial_shape(ov::PartialShape{ -1 }); + + if (p.dynamic_paddings) { + const auto padding_axis = 1; + const auto pad_before = p.k_head_size; + const auto pad_after = p.k_head_size * 2; + + query_layout.data_padding._dynamic_dims_mask[padding_axis] = 1; + + auto query_data_layout = query_mem->get_layout(); + auto padded_query_data_layout = query_data_layout; + padded_query_data_layout.data_padding._lower_size[padding_axis] = pad_before; + padded_query_data_layout.data_padding._upper_size[padding_axis] = pad_after; + + auto new_query_memory = get_test_engine().allocate_memory(padded_query_data_layout, false); + + mem_lock query_mem_lock(query_mem, get_test_stream()); + mem_lock new_query_mem_lock(new_query_memory, get_test_stream()); + + auto query_data_shape = query_data_layout.get_shape(); + for (size_t b = 0; b < query_data_shape[0]; b++) { + for (size_t f = 0; f < query_data_shape[1]; f++) { + auto input_offset = + query_data_layout.get_linear_offset(cldnn::tensor(static_cast(b), static_cast(f), 0, 0, 0, 0)); + auto output_offset = + padded_query_data_layout.get_linear_offset(cldnn::tensor(static_cast(b), static_cast(f), 0, 0, 0, 0)); + + new_query_mem_lock[output_offset] = query_mem_lock[input_offset]; + } + } + query_mem = new_query_memory; + } + + std::vector pa_inputs = { + input_info("query"), + input_info("key"), + input_info("value"), + input_info("key_cache"), + input_info("value_cache"), + input_info("past_lens"), + input_info("subsequence_begins"), + input_info("block_indices"), + input_info("block_indices_begins"), + input_info("scale"), + input_info("sliding_window"), + input_info("alibi"), + input_info("max_context_len"), + input_info("score_aggregation_window"), + input_info("rotated_block_indices"), + input_info("rotation_deltas"), + input_info("rotation_trig_lut_modified"), + input_info("xattention_threshold"), + input_info("xattention_block_size"), + input_info("xattention_stride"), + }; + + auto pa_prim = paged_attention("paged_attention", pa_inputs); + + pa_prim.k_head_size = p.k_head_size; + pa_prim.v_head_size = p.v_head_size; + pa_prim.kv_heads_num = p.num_heads; + pa_prim.heads_num = p.num_heads; + pa_prim.scale_val = pam.get_default_scale(); + pa_prim.has_alibi = false; + pa_prim.num_outputs = p.scores_mode == ScoresMode::DISABLED ? 1 : 2; + pa_prim.has_rotated_blocks = p.rotation_config.apply_rotation; + pa_prim.has_score_aggregation = p.scores_mode == ScoresMode::SNAPKV; + pa_prim.sliding_window = p.sliding_window_size; + pa_prim.is_key_by_channel = (p.key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL); + + topology topology; + + topology.add( + input_layout("query", query_layout), + input_layout("key", key_layout), + input_layout("value", value_layout), + input_layout("key_cache", key_cache_layout), + input_layout("value_cache", value_cache_layout), + input_layout("past_lens", past_lens_layout), + input_layout("subsequence_begins", subsequence_begins_layout), + input_layout("block_indices", block_indices_layout), + input_layout("block_indices_begins", block_indices_begins_layout), + input_layout("scale", scale_layout), + input_layout("sliding_window", sliding_window_layout), + input_layout("alibi", alibi_layout), + input_layout("max_context_len", max_context_len_layout), + input_layout("score_aggregation_window", score_aggregation_window_layout), + pa_prim, + reorder("output_data", input_info("paged_attention", 0), format::bfyx, data_types::f16) + ); + + if (p.scores_mode != ScoresMode::DISABLED) { + topology.add(reorder("output_scores", input_info("paged_attention", 1), format::bfyx, data_types::f16)); + } + + { + topology.add(input_layout("rotated_block_indices", rotated_block_indices_layout)); + topology.add(input_layout("rotation_deltas", rotation_deltas_layout)); + topology.add(input_layout("rotation_trig_lut", rotation_trig_lut_layout)); + + // add dummy activation operation to simulate an empty PA `rotation_trig_lut` buffer for shapes like [0, k_head_size] + topology.add(activation("rotation_trig_lut_modified", input_info("rotation_trig_lut"), activation_func::none)); + + topology.add(input_layout("xattention_threshold", xattention_threshold_layout)); + topology.add(input_layout("xattention_block_size", xattention_block_size_layout)); + topology.add(input_layout("xattention_stride", xattention_stride_layout)); + } + + ExecutionConfig config = get_test_default_config(get_test_engine()); + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + // FlashAttn v1 or v2? + config.set_property(ov::intel_gpu::could_use_flashattn_v2(p.disable_flashattn_v2)); + config.set_property(ov::internal::key_cache_quant_mode(p.key_cache_quant_mode)); + network::ptr network = get_network(get_test_engine(), topology, config, get_test_stream_ptr(), false); + network->set_input_data("query", query_mem); + network->set_input_data("key", key_mem); + network->set_input_data("value", value_mem); + network->set_input_data("key_cache", key_cache_mem); + network->set_input_data("value_cache", value_cache_mem); + network->set_input_data("past_lens", past_lens_mem); + network->set_input_data("subsequence_begins", subsequence_begins_mem); + network->set_input_data("block_indices", block_indices_mem); + network->set_input_data("block_indices_begins", block_indices_begins_mem); + network->set_input_data("scale", scale_mem); + network->set_input_data("sliding_window", sliding_window_mem); + network->set_input_data("alibi", alibi_mem); + network->set_input_data("max_context_len", max_context_len_mem); + network->set_input_data("score_aggregation_window", score_aggregation_mem); + network->set_input_data("rotated_block_indices", rotated_block_indices_mem); + network->set_input_data("rotation_deltas", rotation_deltas_mem); + network->set_input_data("rotation_trig_lut", rotation_trig_lut_mem); + network->set_input_data("xattention_threshold", xattention_threshold_mem); + network->set_input_data("xattention_block_size", xattention_block_size_mem); + network->set_input_data("xattention_stride", xattention_stride_mem); + + auto outputs = network->execute(); + + cldnn::memory::ptr output_data_mem = nullptr; + cldnn::memory::ptr output_scores_mem = nullptr; + + output_data_mem = outputs.at("output_data").get_memory(); + if (p.scores_mode != ScoresMode::DISABLED) { + output_scores_mem = outputs.at("output_scores").get_memory(); + } + auto ref_data = xAttentionReference(pam).get_reference(); + for (size_t i = 0; i < ref_data.first.size(); i++) { + std::cout << i << "reference = " << ref_data.first[i] << std::endl; + } + compare(output_data_mem, output_scores_mem, ref_data); + } + + void compare(memory::ptr data_output_mem, memory::ptr scores_output_mem, std::pair, std::vector> ref_data) { + if (data_output_mem) { + ASSERT_EQ(data_output_mem->count(), ref_data.first.size()); + mem_lock mem_ptr(data_output_mem, get_test_stream()); + for (size_t i = 0; i < data_output_mem->count(); i++) { + std::cout << i << ": result = " << mem_ptr[i] << ", reference = " << ref_data.first[i] << std::endl; + } + // for (size_t i = 0; i < data_output_mem->count(); i++) { + // ASSERT_NEAR(mem_ptr[i], ref_data.first[i], tolerance) << " at index=" << i; + // } + } + + if (scores_output_mem) { + ASSERT_EQ(scores_output_mem->count(), ref_data.second.size()); + mem_lock mem_ptr(scores_output_mem, get_test_stream()); + for (size_t i = 0; i < scores_output_mem->count(); i++) { + ASSERT_NEAR(mem_ptr[i], ref_data.second[i], tolerance) << " at index=" << i; + } + } + } +}; + +struct xattention_test_params { + std::vector subsequences; + int num_heads; + int k_head_size; + int v_head_size; + int block_size; + int sliding_window_size; + bool kv_cache_compression; + ov::internal::CacheQuantMode key_cache_quant_mode; + bool dynamic_paddings; + ScoresMode scores_mode; + CacheRotationDescriptor rotation_config; + bool disable_flashattn_v2; +}; + +class xattention_test : public xAttentionTest {}; +TEST_P(xattention_test, basic) { + auto p = GetParam(); + + execute(p); +} + +const auto ENABLE_CACHE_COMPRESSION = true; +const auto DISABLE_CACHE_COMPRESSION = false; +const auto DISABLE_SCORES = ScoresMode::DISABLED; +const auto ENABLE_SCORES = ScoresMode::LAST_TOKEN; +const auto ENABLE_SCORES_SNAPKV = ScoresMode::SNAPKV; +const auto PER_BLOCK_ROTATION = CacheRotationDescriptor{true, true}; +const auto PER_TOKEN_ROTATION = CacheRotationDescriptor{true, false}; +const auto DISABLE_ROTATION = CacheRotationDescriptor{false, false}; +const auto STATIC_INPUT_PAD = false; +const auto DYNAMIC_INPUT_PAD = true; +const auto ENABLE_FA_V2 = false; +const auto DISABLE_FA_V2 = true; + +INSTANTIATE_TEST_SUITE_P(smoke_xattention, + xattention_test, + ::testing::ValuesIn(std::vector{ + +#if ENABLE_PA_CM_PATH + /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, + token_size>=32, disable_mix_mode */ + // xattention_test_params{ {{32, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + xattention_test_params{ {{1024, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + +// xattention_test_params{ {{1024, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, +// DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + +// xattention_test_params{ {{1, 31}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, +// DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, 32}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, +// ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, +// 1023}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, +// DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, 127}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, +// STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, 129}}, 2, 64, 64, 256, 0, +// DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token +// xattention_test_params{ {{1, 32}}, 28, 128, 128, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, +// DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token +#endif +})); diff --git a/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp b/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp new file mode 100644 index 00000000000000..60e5355894f62b --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp @@ -0,0 +1,658 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test_utils.h" +#include "random_generator.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace cldnn; +using namespace ov::intel_gpu; +using namespace ::tests; + +/* +* PagedAttention inputs: +* [0]: query +* shape: [batch_size_in_tokens, num_heads * head_size], type: f16 +* [1]: key +* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 +* [2]: value  +* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 +* [3]: key_cache +* shape: [num_blocks, num_kv_heads, head_size, block_size], type: f16 +* [4]: value_cache +* shape: [num_blocks, num_kv_heads, block_size, head_size], type: f16 +* [5]: past_lens +* shape: [batch_size_in_sequences], type: i32 +* [6]: subsequence_begins +* shape: [batch_size_in_sequences + 1], type: i32 +* [7]: block_indices +* Shape: [num_blocks], type: i32 +* [8]: block_indices_begins +* Shape: [batch_size_in_sequences + 1], type: i32 +* [9]: scale, optional +* [10]: sliding_window, optional +* [11]: alibi_slopes, optional +* [12]: max_context_len +* shape: [], type: i32 +* [13]: score_aggregation_window​, optional​, shape: [batch_size_in_sequences] +* [14]: rotated_block_indices​, optional​ +* shape: [num_rotated_blocks]​, type: i32 +* [15]: rotation_deltas​, optional​ +* shape: [num_rotated_blocks, BLOCK_SIZE]​ || [num_rotated_blocks, 1]​, type: i32 +* [16]: rotation_trig_lut​, optional​ +* shape: [max_num_batched_tokens / BLOCK_SIZE, head_size]​ || [max_num_batched_tokens, head_size], type: f16 +*/ + + +enum class ScoresMode { + DISABLED = 0, + LAST_TOKEN, + SNAPKV +}; + +struct SubsequenceDescriptor { + int num_tokens; + int past_len; +}; + +struct CacheRotationDescriptor { + bool apply_rotation; + // configures 2nd dimension of rotation_deltas + // if per_block is true, single value is used for all tokens inside the block + // otherwise, each token uses an independent value + bool per_block; +}; + +struct PagedAttentionManager { + int num_heads; + int k_head_size; + int v_head_size; + int block_size; + int sliding_window_size; + bool kv_cache_compression; + ov::internal::CacheQuantMode key_cache_quant_mode; + bool has_score_aggregation; + CacheRotationDescriptor rotation_config; + std::vector subsequence_descs; + + // per-subsequence QKV inputs + std::vector> query_data; // {[1, num_tokens, num_heads, k_head_size], ..} + std::vector> key_data; // {[1, past_len + num_tokens, num_heads, k_head_size], ..} + std::vector> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..} + + // common PA inputs + std::vector past_lens; + std::vector subsequence_begins; + std::vector block_indices; + std::vector block_indices_begins; + std::vector max_context_len; + std::vector score_aggregation_window; + + // score aggregation related inputs + std::vector score_aggregation; + + // rotation related inputs + std::vector rotated_block_indices; + std::vector rotation_deltas; + std::vector rotation_trig_lut; + + std::vector xattention_threshold; + std::vector xattention_block_size; + std::vector xattention_stride; + + cldnn::engine& test_engine; + cldnn::stream& test_stream; + tests::random_generator& rg; + + PagedAttentionManager(tests::random_generator& rg, + cldnn::engine& engine, + cldnn::stream& stream, + const std::vector& subsequence_descs, + int num_heads, + int k_head_size, + int v_head_size, + int block_size, + int sliding_window_size, + bool kv_cache_compression, + ov::internal::CacheQuantMode key_cache_quant_mode, + bool has_score_aggregation, + CacheRotationDescriptor rotation_config) + : num_heads(num_heads) + , k_head_size(k_head_size) + , v_head_size(v_head_size) + , block_size(block_size) + , sliding_window_size(sliding_window_size) + , kv_cache_compression(kv_cache_compression) + , key_cache_quant_mode(key_cache_quant_mode) + , has_score_aggregation(has_score_aggregation) + , rotation_config(rotation_config) + , subsequence_descs(subsequence_descs) + , test_engine(engine) + , test_stream(stream) + , rg(rg) { + // init subsequence_begins and block_indices_begins + subsequence_begins.push_back(0); + block_indices_begins.push_back(0); + + int max_len = 0; + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + const auto& subsequence_desc = subsequence_descs[i]; + max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); + + query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); + key_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); + value_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); + + past_lens.push_back(subsequence_desc.past_len); + int subsequence_start_pos = subsequence_begins[i]; + int subsequence_end_pos = subsequence_start_pos + subsequence_desc.num_tokens; + subsequence_begins.push_back(subsequence_end_pos); + + int subsequence_length = subsequence_desc.num_tokens + subsequence_desc.past_len; + int required_blocks = ceil_div(subsequence_length, block_size); + int start_block_idx = block_indices.empty() ? 0 : block_indices.back() + 1; + int end_block_idx = start_block_idx + required_blocks; + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + block_indices.push_back(block_idx); + } + + int block_indices_start_pos = block_indices_begins[i]; + int block_indices_end_pos = block_indices_start_pos + required_blocks; + block_indices_begins.push_back(block_indices_end_pos); + } + max_context_len.push_back(max_len); + + if (rotation_config.apply_rotation) { + // iterate over KV-cache blocks and apply cache rotation to every second + // fully occupied block + for (size_t i = 0; i < subsequence_descs.size(); i++) { + const auto& subsequence_desc = subsequence_descs[i]; + int past_len = subsequence_desc.past_len; + int start_block_idx = block_indices_begins[i]; + for (int block_idx = 1; block_idx < past_len / block_size; block_idx++) { + if (block_idx % 2 != 0) { + rotated_block_indices.push_back(start_block_idx + block_idx); + } + } + } + + if (!rotated_block_indices.empty()) { + rotation_deltas = generate_rotation_deltas_data(rg, + max_context_len[0], + rotated_block_indices.size(), + block_size, + rotation_config.per_block); + rotation_trig_lut = generate_rotation_trig_lut_data(rg, max_context_len[0], k_head_size); + } + } + + if (has_score_aggregation) { + for (const auto& subsequence_desc : subsequence_descs) { + const auto max_tokens = 10; + auto max_window_size = std::min(subsequence_desc.num_tokens, max_tokens); + auto window_size = rg.generate_random_val(1, max_window_size); + score_aggregation.push_back(window_size); + } + } + } + + memory::ptr get_query_memory() { + return get_QKV_memory(query_data, k_head_size, false); + } + + memory::ptr get_key_memory() { + return get_QKV_memory(key_data, k_head_size, true); + } + + memory::ptr get_value_memory() { + return get_QKV_memory(value_data, v_head_size, true); + } + +#if ENABLE_PA_CM_PATH + memory::ptr get_key_cache_memory() { + auto key_cache_dt = data_types::f16; + auto adjusted_head_size = k_head_size; + if (kv_cache_compression) { + key_cache_dt = data_types::i8; + adjusted_head_size += 4; + } + + auto num_blocks = block_indices.back() + 1; + auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; + auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; + auto memory = test_engine.allocate_memory(key_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len + 1, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + + input_token_offset * num_heads * v_head_size + + head_idx * v_head_size; + if (kv_cache_compression) { + auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); + auto quantized_data_ptr = quantized_data.data(); + + // shape: [num_blocks, num_heads, block_size, adjusted_head_size] + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + + token_idx * v_head_size; + set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); + + size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } else { + // shape: [num_blocks, num_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + + head_idx * block_size * v_head_size + + token_idx * v_head_size; + + set_values(test_stream, memory, data_ptr, v_head_size, output_offset); + } + } + } + } + } + } + + return memory; + } + +#else + memory::ptr get_key_cache_memory() { + auto key_cache_dt = data_types::f16; + auto adjusted_head_size = k_head_size; + auto adjusted_block_size = block_size; + if (kv_cache_compression) { + key_cache_dt = data_types::i8; + if (key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { + adjusted_block_size += 4; + } else { + adjusted_head_size += 4; + } + } + + auto num_blocks = block_indices.back() + 1; + auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, adjusted_head_size, adjusted_block_size }; + auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; + auto memory = test_engine.allocate_memory(key_cache_layout); + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len + 1, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + // quantize by channel + if (kv_cache_compression && key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { + std::vector token_block(block_size); + for (int token_idx = 0; token_idx < last_token_idx; ++token_idx) { + size_t input_token_offset = block_idx * block_size + token_idx; + token_block[token_idx] = *(key_data[i].data() + input_token_offset * num_heads * k_head_size + head_idx * k_head_size + k_head_size_idx); + } + auto [quantized_data, scale, zp] = quantize_data(token_block.data(), last_token_idx, true); + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * adjusted_head_size * adjusted_block_size + + head_idx * adjusted_head_size * adjusted_block_size; + size_t output_offset = output_block_offset + + k_head_size_idx * adjusted_block_size; + set_values(test_stream, memory, quantized_data.data(), last_token_idx, output_offset); + size_t comp_offset = (output_offset + block_size)/2; + set_values(test_stream, memory, &scale, 1, comp_offset); + set_values(test_stream, memory, &zp, 1, comp_offset + 1); + } + } + } + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + if (kv_cache_compression) { + if (key_cache_quant_mode == ov::internal::CacheQuantMode::BY_TOKEN) { + // quantize by token + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + + input_token_offset * num_heads * k_head_size + + head_idx * k_head_size; + // shape: [num_blocks, num_heads, adjusted_head_size, block_size] + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * adjusted_head_size * block_size + + head_idx * adjusted_head_size * block_size; + + auto [quantized_data, scale, zp] = quantize_data(data_ptr, k_head_size); + for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { + auto quantized_data_ptr = quantized_data.data() + k_head_size_idx; + + size_t output_offset = output_block_offset + + k_head_size_idx * block_size + + token_idx; + + set_values(test_stream, memory, quantized_data_ptr, 1, output_offset); + } + size_t comp_offset = (output_block_offset + k_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } + } else { + for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + + input_token_offset * num_heads * k_head_size + + head_idx * k_head_size + k_head_size_idx; + + // shape: [num_blocks, num_heads, k_head_size, block_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * k_head_size * block_size + + head_idx * k_head_size * block_size + + k_head_size_idx * block_size + + token_idx; + + set_values(test_stream, memory, data_ptr, 1, output_offset); + } + } + } + } + } + } + } + + return memory; + } +#endif + + memory::ptr get_value_cache_memory() { + auto value_cache_dt = data_types::f16; + auto adjusted_head_size = v_head_size; + if (kv_cache_compression) { + value_cache_dt = data_types::i8; + adjusted_head_size += 4; + } + + auto num_blocks = block_indices.back() + 1; + auto value_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; + auto value_cache_layout = layout{ value_cache_shape, value_cache_dt, format::bfyx }; + auto memory = test_engine.allocate_memory(value_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len + 1, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = value_data[i].data() + + input_token_offset * num_heads * v_head_size + + head_idx * v_head_size; + if (kv_cache_compression) { + auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); + auto quantized_data_ptr = quantized_data.data(); + + // shape: [num_blocks, num_heads, block_size, adjusted_head_size] + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + + token_idx * v_head_size; + set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); + + size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } else { + // shape: [num_blocks, num_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + + head_idx * block_size * v_head_size + + token_idx * v_head_size; + + set_values(test_stream, memory, data_ptr, v_head_size, output_offset); + } + } + } + } + } + } + + return memory; + } + + memory::ptr get_past_lens_memory() { + return get_memory_from_vec(past_lens); + } + + memory::ptr get_subsequence_begins_memory() { + return get_memory_from_vec(subsequence_begins); + } + + memory::ptr get_block_indices_memory() { + return get_memory_from_vec(block_indices); + } + + memory::ptr get_block_indices_begins_memory() { + return get_memory_from_vec(block_indices_begins); + } + + memory::ptr get_scale_memory() { + std::vector scale = { ov::float16(get_default_scale()) }; + return get_memory_from_vec(scale); + } + + memory::ptr get_sliding_window_memory() { + std::vector sliding_window = { 0 }; + return get_memory_from_vec(sliding_window); + } + + memory::ptr get_alibi_memory() { + std::vector alibi; + return get_memory_from_vec(alibi); + } + + memory::ptr get_max_context_len_memory() { + return get_memory_from_vec(max_context_len); + } + + memory::ptr get_score_aggregation() { + return get_memory_from_vec(score_aggregation); + } + + memory::ptr get_rotated_block_indices_memory() { + return get_memory_from_vec(rotated_block_indices); + } + + memory::ptr get_rotation_deltas_memory() { + auto mem = get_memory_from_vec(rotation_deltas); + auto layout = mem->get_layout(); + auto last_dim = rotation_config.per_block ? 1 : block_size; + layout.set_partial_shape(ov::PartialShape{ static_cast(rotated_block_indices.size()), last_dim }); + + return test_engine.reinterpret_buffer(*mem, layout); + } + + memory::ptr get_rotation_trig_lut_memory() { + auto mem = get_memory_from_vec(rotation_trig_lut); + auto layout = mem->get_layout(); + layout.set_partial_shape(ov::PartialShape{ max_context_len[0], k_head_size }); + + if (rotated_block_indices.empty()) { + auto empty_layout = mem->get_layout(); + empty_layout.set_partial_shape(ov::PartialShape{ 0, k_head_size }); + return test_engine.reinterpret_buffer(*mem, empty_layout); + } + + return test_engine.reinterpret_buffer(*mem, layout); + } + + memory::ptr get_xattention_threshold_memory() { + auto mem = get_memory_from_vec(xattention_threshold); + auto layout = mem->get_layout(); + layout.set_partial_shape(ov::PartialShape{ 1 }); + + if (xattention_threshold.empty()) { + auto empty_layout = mem->get_layout(); + empty_layout.set_partial_shape(ov::PartialShape{ 0 }); + return test_engine.reinterpret_buffer(*mem, empty_layout); + } + + return test_engine.reinterpret_buffer(*mem, layout); + } + + memory::ptr get_xattention_block_size_memory() { + return get_memory_from_vec(xattention_block_size); + } + + memory::ptr get_xattention_stride_memory() { + return get_memory_from_vec(xattention_stride); + } + + float get_default_scale() { + return static_cast(1.f / std::sqrt(k_head_size)); + } + +private: + template + memory::ptr get_memory_from_vec(std::vector& input_data) { + auto data_size = input_data.empty() ? 1 : input_data.size(); + auto shape = ov::PartialShape{ static_cast(data_size) }; + auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + auto memory = test_engine.allocate_memory(layout); + + if (input_data.empty()) { + auto shape = ov::PartialShape{0}; + auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + return test_engine.reinterpret_buffer(*memory, layout); + } + + set_values(test_stream, memory, input_data.data(), input_data.size(), 0); + + return memory; + } + + memory::ptr get_QKV_memory(std::vector>& input_data, int k_head_size, bool skip_past_len) { + int total_tokens = 0; + for (const auto& subsequence_desc : subsequence_descs) + total_tokens += subsequence_desc.num_tokens; + + auto query_shape = ov::PartialShape{ total_tokens, num_heads * k_head_size }; + auto query_layout = layout{ query_shape, data_types::f16, format::bfyx }; + auto memory = test_engine.allocate_memory(query_layout); + + for (int subsequence_idx = 0; subsequence_idx < static_cast(subsequence_descs.size()); subsequence_idx++) { + for (int token_idx = 0; token_idx < subsequence_descs[subsequence_idx].num_tokens; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = token_idx; + // as generated data stored in vectors includes past_len, ignore it for KV inputs + if (skip_past_len) + input_token_offset += subsequence_descs[subsequence_idx].past_len; + + ov::float16* data_ptr = input_data[subsequence_idx].data() + + input_token_offset * num_heads * k_head_size + + head_idx * k_head_size; + + size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx; + size_t output_offset = output_token_offset * num_heads * k_head_size + + head_idx * k_head_size; + + set_values(test_stream, memory, data_ptr, k_head_size, output_offset); + } + } + } + + return memory; + } + + template + static void set_values(stream& stream, memory::ptr mem, T* vals, size_t size, size_t dst_offset) { + mem_lock mem_ptr(mem, stream); + for (size_t i = 0; i < size; i++) { + mem_ptr[dst_offset + i] = vals[i]; + } + } + + static std::vector generate_input_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t k_head_size) { + const size_t total_elements_num = tokens_num * num_heads * k_head_size; + auto data = rg.generate_random_1d(total_elements_num, -1, 1); + + // test code + // auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 10000); + + return data; + } + + static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { + const size_t total_elements_num = per_block ? rotated_blocks_num + : rotated_blocks_num * block_size; + auto data = rg.generate_random_1d(total_elements_num, 0, static_cast(max_tokens_num - 1)); + + return data; + } + + static std::vector generate_rotation_trig_lut_data(tests::random_generator& rg, size_t max_tokens_num, size_t k_head_size) { + const size_t total_elements_num = max_tokens_num * k_head_size; + auto data = rg.generate_random_1d(total_elements_num, -1, 1); + + return data; + } + + static std::tuple, ov::float16, ov::float16> quantize_data(ov::float16* data, size_t size, bool expand_range = false) { + float min_value = std::numeric_limits::max(); + float max_value = std::numeric_limits::lowest(); + + for (size_t i = 0; i < size; i++) { + min_value = std::min((float)(data[i]), min_value); + max_value = std::max((float)(data[i]), max_value); + } + + float diff_value = 0.001; + if (max_value != min_value) + diff_value = max_value - min_value; + if (expand_range && std::abs(diff_value) <= std::abs(max_value) * 0.1f) { + // compensate too small range + diff_value = (max_value - min_value) + std::max(1.0f, max_value * 0.1f); + } + float scale = (std::numeric_limits::max() - std::numeric_limits::lowest()) / diff_value; + float zp = ((float)-min_value * scale) + std::numeric_limits::lowest(); + + std::vector quantized_data; + quantized_data.resize(size); + + auto convert_char_rte = [](float val) { + float rounded = std::nearbyint(val); + + if (rounded > 127.0f) { + return static_cast(127); + } else if (rounded < -128.0f) { + return static_cast(-128); + } else { + return static_cast(rounded); + } + }; + + for (size_t i = 0; i < size; i++) { + quantized_data[i] = convert_char_rte(data[i] * scale + zp); + } + + scale = 1.0f / scale; + + return std::make_tuple(quantized_data, scale, zp); + } +}; From 13b11227575b7043e870840bb33b106fddb07691 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Fri, 10 Oct 2025 09:20:56 +0800 Subject: [PATCH 35/96] Optimize single_token_finalization kernel with fixed unroll --- .../src/graph/impls/cm/pa_single_token.cm | 115 +----------------- .../impls/cm/pa_single_token_finalization.cm | 47 +++++-- 2 files changed, 37 insertions(+), 125 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index 84bf0accb61383..0537797766bb66 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -28,59 +28,8 @@ #define VNNI_WIDTH 2 #define REG_K (SystolicDepth * VNNI_WIDTH) #define REG_M RepeatCount - -#if 0 -#define HEADS_NUM -#define KV_HEADS_NUM -#define HEAD_SIZE -#define SCALE_FACTOR -#define KV_BLOCK_SIZE -#define KV_PARTITION_SIZE -#define Q_STEP -#define KV_STEP -#define WG_SIZE -#define XE_ARCH -#define KV_CACHE_COMPRESSION -#endif - #define KV_PARTITION_STEP_NUM (KV_PARTITION_SIZE / KV_STEP) -#define DEBUG_ENABLE 0 -#if DEBUG_ENABLE -template -void show(matrix mat) { - for(int m = 0; m < M; m ++) { - printf("\t["); - for(int n = 0; n < N; n ++) { - printf("%8.4f,", mat[m][n]); - } - printf("],\n"); - } - printf("]\n"); -} - -template -void show_u8(matrix mat) { - for(int m = 0; m < M; m ++) { - printf("\t["); - for(int n = 0; n < N; n ++) { - printf("%4d", mat[m][n]); - } - printf("],\n"); - } - printf("]\n"); -} - -template -void show(vector vec) { - printf("\t["); - for(int n = 0; n < N; n ++) { - printf("%8.4f,", vec[n]); - } - printf("]\n"); -} -#endif - #define Q_SLICE_NUM (HEADS_NUM / KV_HEADS_NUM) #if Q_SLICE_NUM > 8 || Q_SLICE_NUM == 1 @@ -151,7 +100,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const uint start_block_idx = block_indices_begins[seq_idx] + kv_partition_idx * (KV_PARTITION_SIZE / KV_BLOCK_SIZE); if(kv_partition_idx * KV_PARTITION_SIZE > kv_len) { - // printf("WG exit: kv_partition_idx=%d, KV_PARTITION_SIZE=%d, kv_len=%d\n", kv_partition_idx, KV_PARTITION_SIZE, kv_len); return; } const uint total_blocks_num = (kv_len + KV_BLOCK_SIZE - 1) / KV_BLOCK_SIZE; @@ -168,11 +116,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( cm_svm_block_read((svmptr_t)(query + qo_offset), Qmat.format()); #endif - //if(kv_head_num_idx==0 && kv_partition_idx == 0) { - // printf("Qmat loaded, kv_head_num_idx=%d\n", kv_head_num_idx); - // show(Qmat); - //} - constexpr uint per_kv_block_element_num = KV_BLOCK_SIZE * KV_HEADS_NUM * (HEAD_SIZE + KV_SCALE_ZP_SIZE / sizeof(KV_ELEMENT_TYPE)); // 4 bytes: scale/zp uint block_num = KV_PARTITION_SIZE / KV_BLOCK_SIZE; @@ -198,9 +141,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint kv_base_offset = blk_indices * per_kv_block_element_num + kv_head_num_idx * (per_kv_block_element_num / KV_HEADS_NUM); uint kv_scale_zp_offset = kv_base_offset + KV_BLOCK_SIZE * HEAD_SIZE; // scale/zp offset - // printf("seq_idx = %d, head_num_idx = %d, kv_partition_idx = %d, start_block_idx = %d, block_idx = %d, blk_indices = %d, KV_PARTITION_SIZE = %d, KV_BLOCK_SIZE = %d, total_blocks_num = %d, kv_pitch = %d, kv_base_offset = %d\n", - // seq_idx, head_num_idx, kv_partition_idx, start_block_idx, block_idx, blk_indices, KV_PARTITION_SIZE, KV_BLOCK_SIZE, total_blocks_num, kv_pitch, kv_base_offset); - #if USE_LSC_BLOCK_2D_DESC #if KV_CACHE_COMPRESSION // Transpose only support dword and qwork @@ -239,9 +179,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #endif for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += KV_STEP, ki++) { - // auto rSvec = rS[ki].format(); - // uint kv_offset_y = kv_pos; - #if KV_CACHE_COMPRESSION vector temp_scale, temp_zp; temp_scale.select(0) = scale_vec.select(kv_pos); @@ -275,13 +212,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( quant_dst.row(r+1) = quant_src.select<2,1,8,2>(r,1); } - #if DEBUG_ENABLE - printf("Kt_quant_temp: k = %d\n", k); - show_u8(Kt_quant_temp.format()); - printf("Kt_quant_vnni: k = %d\n", k); - show_u8(Kt_quant.format()); - #endif - #pragma unroll for(int r = 0; r < REG_K; r++) { Kt[r] = Kt_quant[r] - temp_zp.format()[r%2]; //vector - vector @@ -290,10 +220,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #else cm_load(Kt.format(), b2dK.set_block_y(kv_pos)); #endif - //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { - // printf("Kt: k = %d\n", k); - // show(Kt.format()); - //} #else matrix temp; uint cur_kv_offset = kv_offset + kv_pos * kv_stride + k * 2;// uint --> half @@ -319,10 +245,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( Kt.format(), Qmat_data.format()); rS.select(0, ki*REG_N) += rS_data; - - //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { - // show(rS_data); - //} #else #pragma unroll for(int qi = 0; qi < Q_SLICE_NUM; qi ++) { @@ -338,11 +260,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } } - //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { - // printf("rS:\n"); - // show(rS); - //} - // online softmax vector cur_sum = 0.0f; vector cur_lse = 0.0f; @@ -399,11 +316,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( cur_sum[qi] = cm_sum(rPv[0]); } - //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { - // printf("Pmat:\n"); - // show(Pmat); - //} - //# rO = P * V #if Q_RepeatCount != 1 matrix Omat = 0; @@ -431,10 +343,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint kv_x1 = HEAD_SIZE*sizeof(half); uint kv_y1 = KV_BLOCK_SIZE; #endif - - //if(kv_partition_idx==kv_partition_num - 1 && head_num_idx == HEADS_NUM - 1) { - // printf("leftover_size = %d, leftover_aligned_size = %d, XE_ARCH = %d, KV_BLOCK_SIZE = %d\n", leftover_size, leftover_aligned_size, XE_ARCH, KV_BLOCK_SIZE); - //} + uint kv_pos_end = KV_BLOCK_SIZE; if(block_idx == block_num - 1 && leftover_size > 0) { kv_pos_end = leftover_size % KV_BLOCK_SIZE; @@ -474,20 +383,11 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( matrix Vt_quant; cm_load(Vt_quant.format(), b2dV.set_block_y(kv_pos)); - #if DEBUG_ENABLE - //printf("Vt_quant: k = %d\n", k); - //show_u8(Vt_quant.format()); - //show(temp_scale); - //show(temp_zp); - //printf("\n"); - #endif - #pragma unroll for(int r = 0; r < REG_K; r++) { VmatNormal[r] = Vt_quant[r] - temp_zp[r]; // vector - scalar VmatNormal[r] = cm_mul(VmatNormal[r], temp_scale[r]); // vector * scalar } - // show(VmatNormal.format()); if(kv_pos_end - kv_pos < KV_STEP) { #pragma unroll @@ -499,10 +399,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #else cm_load(Vmat[0].format(), b2dV.set_block_y(kv_pos)); #endif - #if DEBUG_ENABLE - //printf("Vmat: k = %d\n", k); - //show(Vmat.format()); - #endif #else matrix temp; uint cur_kv_offset = kv_offset + kv_pos * kv_stride + k; @@ -532,19 +428,10 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( Pmat_slice[ki].format()); } #endif - //if(kv_partition_idx==kv_partition_num - 1 && head_num_idx == 27) { - // printf("Omat[%d][%d]:\n",kv_pos, k); - // show(Omat); - //} } } } - //if(kv_partition_idx==kv_partition_num - 1 && kv_head_num_idx == KV_HEADS_NUM - 1) { - // printf("Omat:\n"); - // show(Omat); - //} - //# save Output for (int qi = 0; qi < Q_SLICE_NUM; qi++) { matrix cur_O_f32; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm index 66d8acf6df0ef2..dcf0a17327e201 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm @@ -1,12 +1,18 @@ -// Copyright (C) 2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#if 0 -#define HEADS_NUM -#define HEAD_SIZE -#define REDUCE_SPLIT_SIZE -#endif +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ //cm_sdpa_2nd_reduce extern "C" _GENX_MAIN_ void KERNEL_NAME( @@ -30,7 +36,16 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( for(int k = 1; k < kv_partition_num; k ++) { lse_max = cm_max(lse_vec[k], lse_max); } - for(int k = 0; k < kv_partition_num; k ++) { + + int iter = kv_partition_num / 16; + for(int k = 0; k < iter * 16; k += 16) { + #pragma unroll + for(int ki = k; ki < k + 16; ki ++) { + float lse_value = cm_exp((lse_vec[ki] - lse_max)*log2e); + total_lse += lse_value; + } + } + for(int k = iter * 16; k < kv_partition_num; k ++) { float lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); total_lse += lse_value; } @@ -40,7 +55,17 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( matrix out_mat = 0; matrix data_mat; uint input_offset = batch * total_partition_num * HEAD_SIZE + head * kv_partition_num * HEAD_SIZE + offset; - for(int k = 0; k < kv_partition_num; k ++) { + + for(int k = 0; k < iter * 16; k += 16) { + #pragma unroll + for(int ki = k; ki < k + 16; ki ++) { + cm_svm_block_read((svmptr_t)(input + input_offset), data_mat.format()); + input_offset += HEAD_SIZE; + float lse_value = cm_exp((lse_vec[ki] - lse_max)*log2e); + out_mat_f32 += cm_mul(data_mat, (float)(lse_value/total_lse)); + } + } + for(int k = iter * 16; k < kv_partition_num; k ++) { cm_svm_block_read((svmptr_t)(input + input_offset), data_mat.format()); input_offset += HEAD_SIZE; float lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); From 24d6b80fcfd67dd4a688a22789bdc67c48cff94e Mon Sep 17 00:00:00 2001 From: Chen Peter Date: Fri, 10 Oct 2025 20:42:11 +0800 Subject: [PATCH 36/96] Fix win build --- .../tests/unit/test_cases/paged_attention_gpu_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 0657488e706f4e..07d23590204263 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1250,9 +1250,9 @@ const auto DYNAMIC_INPUT_PAD = true; const auto ENABLE_FA_V2 = false; const auto DISABLE_FA_V2 = true; -INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ #if ENABLE_PA_CM_PATH +INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, token_size>=32, disable_mix_mode */ paged_attention_test_params{ {{32, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token paged_attention_test_params{ {{1024, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long From 326fc4d30d3d4b8acf09e9c949b84790aed5a85b Mon Sep 17 00:00:00 2001 From: Chen Peter Date: Fri, 10 Oct 2025 20:42:22 +0800 Subject: [PATCH 37/96] Fix win build --- .../tests/unit/test_cases/paged_attention_gpu_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 07d23590204263..3b4ce24b938c07 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1372,5 +1372,5 @@ INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing: paged_attention_test_params{ {{5, 10}}, 2, 64, 64, 16, 2, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{5, 10}}, 2, 64, 64, 16, 2, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 34}, {2, 20}, {10, 34}}, 2, 64, 64, 16, 10, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token -#endif })); +#endif From 508fab3225325a323cd5371ad2cc1b5c01be4d67 Mon Sep 17 00:00:00 2001 From: Chen Peter Date: Fri, 10 Oct 2025 20:42:32 +0800 Subject: [PATCH 38/96] Fix win build --- .../tests/unit/test_cases/paged_attention_gpu_test.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 3b4ce24b938c07..1a30ac20eb4804 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1263,7 +1263,9 @@ INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing: paged_attention_test_params{ {{1, 127}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 129}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 32}}, 28, 128, 128, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token +})); #else +INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ /* with scores output, use SnapKV */ paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token paged_attention_test_params{ {{36, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token From 73669d3ae118f52a85f6a61d90184f7673100fe5 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Sat, 11 Oct 2025 09:21:09 +0800 Subject: [PATCH 39/96] Enable CM PA only in case of XAttention been enabled. --- .../include/intel_gpu/primitives/paged_attention.hpp | 2 -- .../intel_gpu/src/graph/impls/cm/paged_attention.hpp | 7 +++++++ .../src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp | 4 ---- .../intel_gpu/src/graph/registry/paged_attention_impls.cpp | 7 ++++++- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index 2d5a97c49f7166..c750b01ba353d3 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -10,8 +10,6 @@ namespace cldnn { -#define ENABLE_PA_CM_PATH 1 - struct paged_attention : public primitive_base { CLDNN_DECLARE_PRIMITIVE(paged_attention) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp index 45c956ae54b173..85cecbbfe39dc7 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp @@ -30,6 +30,13 @@ struct PagedAttentionImplementationManager : public ImplementationManager { ov::element::i8, }; + // Enable CM PA only in case of XAttention been enabled. May decouple them in future. + auto desc = node.as().get_primitive(); + if (!desc->has_xattention) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false because we enable CM PA when XAttention is enabled. " << std::endl; + return false; + } + auto& engine = node.get_program().get_engine(); const auto& config = node.get_program().get_config(); const auto& info = engine.get_device_info(); diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp index d2ccc46bf01e37..8fccefbc9e5eae 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp @@ -25,13 +25,9 @@ struct PagedAttentionOpt : public ImplementationManager { ov::element::f16, }; static constexpr std::array supported_kv_types = { - #if ENABLE_PA_CM_PATH - ov::element::f32, - #else ov::element::f32, ov::element::f16, ov::element::i8, - #endif }; const auto& q_layout = node.get_input_layout(0); diff --git a/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp b/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp index 738cf4c9e59d95..8283553ad73700 100644 --- a/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp +++ b/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp @@ -8,6 +8,9 @@ #if OV_GPU_WITH_OCL #include "impls/ocl_v2/sdpa/paged_attention_opt.hpp" +#endif + +#if OV_GPU_WITH_CM #include "impls/cm/paged_attention.hpp" #endif @@ -18,8 +21,10 @@ using namespace cldnn; const std::vector>& Registry::get_implementations() { static const std::vector> impls = { +#if OV_GPU_WITH_CM + OV_GPU_CREATE_INSTANCE_CM(cm::PagedAttentionImplementationManager, shape_types::any) +#endif OV_GPU_CREATE_INSTANCE_OCL(ocl::PagedAttentionOpt, shape_types::any) - OV_GPU_CREATE_INSTANCE_OCL(cm::PagedAttentionImplementationManager, shape_types::any) }; return impls; From 45bedf3c0011464f29de1976a9cb6572bfa7f7e2 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Sat, 11 Oct 2025 10:09:42 +0800 Subject: [PATCH 40/96] pass xattention threshold from genai --- .../src/graph/impls/cm/paged_attention_gen.cpp | 14 ++++++++------ .../src/graph/impls/cm/paged_attention_gen.hpp | 2 +- .../src/graph/include/paged_attention_inst.h | 1 + 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 0f06595b9487a8..5c87a64860d1cd 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -208,11 +208,13 @@ size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx) { return paged_attention_past_len; } -const float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx) { - (void) seq_idx; // TODO - - static const char* env = std::getenv("OV_GPU_XATTN_THRESH"); - static const float thresh = env ? std::strtof(env, nullptr) : 0.9; +// TODO: change xattn_thresh from scaler to memory... once we remove the converter node +// between parameter node "xattention_threshold.xxx" and paged_attention node. +const float get_xattn_thresh(const kernel_impl_params& params, const size_t seq_idx) { + const auto& input_mem = params.memory_deps; + const auto threshold_mem = input_mem.at(PagedAttentionInputIdx::XATTENTION_THRESHOLD); + mem_lock lock(threshold_mem, *params.strm); // converted + const auto thresh = static_cast(lock[seq_idx]); return thresh; } @@ -944,7 +946,7 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { const uint32_t q_block = ceil_div(q_stride, sum_per_n_token_in_block); const uint32_t k_block = ceil_div(k_stride, sum_per_n_token_in_block); - const float xattn_thresh = get_xattn_thresh(params, 0); // TODO: seq_idx + const float xattn_thresh = get_xattn_thresh(params); wgs.global = {q_block_pad, heads_num, 1}; wgs.local = {1, 1, 1}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 314032f2fd0afe..889f767df3990a 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -62,7 +62,7 @@ size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx); size_t get_partition_size(const bool has_xattention); size_t get_partition_num(const size_t kv_len, const bool has_xattention); -const float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx); +const float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx = 0); inline size_t get_xattn_block_size(const kernel_impl_params& impl_param) { return impl_param.get_program().get_config().get_xattention_block_size(); } diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h index a832ee6bbf48d4..cfdcc6750ec9f9 100644 --- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h @@ -22,6 +22,7 @@ struct typed_program_node : public typed_program_node_base get_lockable_input_ids() const override { std::set input_ports = { PagedAttentionInputIdx::PAST_LENS, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS, + PagedAttentionInputIdx::XATTENTION_THRESHOLD, PagedAttentionInputIdx::MAX_CONTEXT_LEN }; // debug From b7a9a8b176fb0286053816dc4c2f338eb2943960 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Sat, 11 Oct 2025 10:21:35 +0800 Subject: [PATCH 41/96] xattention_block_size unconfigurable --- .../include/intel_gpu/runtime/internal_properties.hpp | 1 - .../intel_gpu/include/intel_gpu/runtime/options.inl | 1 - .../intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp | 7 ++++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp index 433e5da8c790b9..536eda8e7b06d5 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp @@ -172,7 +172,6 @@ static constexpr Property asym_dynamic_quantiz static constexpr Property shape_predictor_settings{"GPU_SHAPE_PREDICTOR_SETTINGS"}; static constexpr Property, ov::PropertyMutability::RW> load_dump_raw_binary{"GPU_LOAD_DUMP_RAW_BINARY"}; static constexpr Property could_use_flashattn_v2{"GPU_COULD_USE_FLASHATTN_V2"}; -static constexpr Property xattention_block_size{"GPU_XATTN_BLOCK_SIZE"}; static constexpr Property dynamic_quantization_group_size_max{"GPU_DYNAMIC_QUANTIZATION_GROUP_SIZE_MAX"}; static constexpr Property validate_output_buffer{"VALIDATE_OUTPUT_BUFFER"}; } // namespace ov::intel_gpu diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl index 2f923fb8c1637f..1546ea2a7c7570 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl @@ -55,7 +55,6 @@ OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, asym_dynamic_quantization, fals OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, could_use_flashattn_v2, true, "Enable/Disable SDPA primitive executing with FlashAttenV2 online softmax tricks.") OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, dynamic_quantization_threshold, 64, "Apply dynamic quantization only when batch size is larger than this value in OneDNN") OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, weightless_attr, nullptr, "Used to configure ov::WeightlessCacheAttribute for constants that are not loaded from a .bin file. This typically applies to non-IR inputs (e.g., ORT)") -OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, xattention_block_size, 128, "block size for X-Attention sparse.") OV_CONFIG_DEBUG_GLOBAL_OPTION(ov::intel_gpu, help, false, "Print help message for all config options") OV_CONFIG_DEBUG_GLOBAL_OPTION(ov::intel_gpu, verbose, 0, "Enable logging for debugging purposes. The higher value the more verbose output. 0 - Disabled, 4 - Maximum verbosity") diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 889f767df3990a..5ee4455e547dfe 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -30,7 +30,7 @@ constexpr auto get_pa_build_options() { return " -cmc -Qxcm_register_file_size=256"; } -// BLOCK_SIZE can be 16/32/64/128/256 +// BLOCK_SIZE can be 16/256 for legacy and xattn cases respectively #define PA_KV_CACHE_BLOCK_SIZE 16 #define PA_KV_CACHE_BLOCK_SIZE_XATTN 256 @@ -41,6 +41,7 @@ constexpr uint32_t SG_N = 8; constexpr uint32_t BLOCK_WG_M = BLOCK_SG_M * SG_M; constexpr uint32_t BLOCK_WG_N = BLOCK_SG_N * SG_N; constexpr int STRIDE = 16; +constexpr size_t XATTN_BLOCK_SIZE = 128; enum class PagedAttentionStage : uint8_t { GENERATE = 0, PREFILL = 1, MIXED = 2, UNKNOWN = 3 }; struct PagedAttentionRuntimeParams : public ImplRuntimeParams { @@ -63,8 +64,8 @@ size_t get_partition_size(const bool has_xattention); size_t get_partition_num(const size_t kv_len, const bool has_xattention); const float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx = 0); -inline size_t get_xattn_block_size(const kernel_impl_params& impl_param) { - return impl_param.get_program().get_config().get_xattention_block_size(); +inline const size_t get_xattn_block_size(const kernel_impl_params& impl_param) { + return XATTN_BLOCK_SIZE; } class PagedAttentionGeneratorBase : public KernelGenerator { From f9f58beeb88c8c9e03ccb538cb6eed6b820fd636 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Sat, 11 Oct 2025 15:45:34 +0800 Subject: [PATCH 42/96] invalidate sparse atten process if threshold is larger than 1.0. --- .../graph/impls/cm/include/cm_pa_common.hpp | 13 ++++++-- .../src/graph/impls/cm/pa_multi_token.cm | 23 ++++++++------ .../src/graph/impls/cm/paged_attention.cpp | 14 ++++----- .../graph/impls/cm/paged_attention_gen.cpp | 31 +++++++++++-------- 4 files changed, 49 insertions(+), 32 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp index 0535013926ab37..59ef0081413dec 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp @@ -32,6 +32,7 @@ void pa_lsc_u8( #if SPARSE_BLOCK_SIZE > 1 svmptr_t sparse_mask_base [[type("svmptr_t")]], svmptr_t wg_sparse_mask_base [[type("svmptr_t")]], + bool validate, #endif svmptr_t o_base [[type("svmptr_t")]], int32_t past_lens, @@ -93,9 +94,11 @@ void pa_lsc_u8( auto load_slm_KV = [&](int kv_pos) { if (kv_pos < kv_stop) { #if SPARSE_BLOCK_SIZE > 1 - if (skip_load(kv_pos)) { - slm_buff_id_write++; - return; + if (validate) { + if (skip_load(kv_pos)) { + slm_buff_id_write++; + return; + } } #endif auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; @@ -201,11 +204,13 @@ void pa_lsc_u8( #if SPARSE_BLOCK_SIZE > 1 + if (validate) { if (skip_compute(kv_pos)) { if constexpr (use_causal_mask) causal_left -= kv_step; continue; } + } #endif { @@ -277,6 +282,7 @@ void pa_kernel_lsc_prefetch_f16( #if SPARSE_BLOCK_SIZE > 1 svmptr_t sparse_mask_base [[type("svmptr_t")]], svmptr_t wg_sparse_mask_base [[type("svmptr_t")]], + bool validate, #endif svmptr_t o_base [[type("svmptr_t")]], int32_t past_lens, @@ -341,6 +347,7 @@ void pa_kernel_lsc_prefetch_f16( cm_prefetch(prefetch_K.set_block_x(0)); #if SPARSE_BLOCK_SIZE > 1 + if (validate) { auto kv_start_block = kv_pos/ SPARSE_BLOCK_SIZE; bool sparse_mask = *(reinterpret_cast(sparse_mask_base) + kv_start_block); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm index d85d4b692a36ea..f9d10d85b4fdf5 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm @@ -45,7 +45,9 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( bool* sparse_block_mask_wg [[type("svmptr_t")]], int q_len, int num_q_blocks, - int num_k_blocks) { + int num_k_blocks, + // validate sparse atten process + bool validate) { #else int q_len) { #endif @@ -125,12 +127,15 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint q_offset = (q_start_sg*num_heads + h)*head_size; #if SPARSE_BLOCK_SIZE > 1 - //# sparse_block_mask [num_heads, num_q_blocks, num_k_blocks] - //# sparse_block_mask_wg [num_heads, wg_count_along_query, num_k_blocks] - auto q_start_block = q_start_sg/ SPARSE_BLOCK_SIZE; - bool* block_mask_base = sparse_block_mask + (h * num_q_blocks + q_start_block) * num_k_blocks; - bool* wg_block_mask_base = sparse_block_mask_wg + (h * cm_group_count(2) + wg_id) * num_k_blocks; - // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d, x-attn: %d, %dx%d, %p, %p\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop, q_start_block, num_q_blocks, num_k_blocks, sparse_block_mask, block_mask_base); + bool *block_mask_base, *wg_block_mask_base; + if (validate) { + //# sparse_block_mask [num_heads, num_q_blocks, num_k_blocks] + //# sparse_block_mask_wg [num_heads, wg_count_along_query, num_k_blocks] + auto q_start_block = q_start_sg/ SPARSE_BLOCK_SIZE; + block_mask_base = sparse_block_mask + (h * num_q_blocks + q_start_block) * num_k_blocks; + wg_block_mask_base = sparse_block_mask_wg + (h * cm_group_count(2) + wg_id) * num_k_blocks; + // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d, x-attn: %d, %dx%d, %p, %p\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop, q_start_block, num_q_blocks, num_k_blocks, sparse_block_mask, block_mask_base); + } #endif #if CMPA_KVCACHE_U8 @@ -150,7 +155,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #if SPARSE_BLOCK_SIZE > 1 reinterpret_cast(block_mask_base), reinterpret_cast(wg_block_mask_base), - + validate, #endif reinterpret_cast(output + q_offset), past_q_lens, @@ -169,7 +174,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #if SPARSE_BLOCK_SIZE > 1 reinterpret_cast(block_mask_base), reinterpret_cast(wg_block_mask_base), - + validate, #endif reinterpret_cast(output + q_offset), past_q_lens, diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index c79f6dd468f6d0..e97f7aeeb3b108 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -50,8 +50,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { add_stage(pa_single_token, params); add_stage(pa_single_token_finalization, params); add_stage(pa_multi_token, params); - const size_t xattn_block_size = get_xattn_block_size(params); - if (desc->has_xattention && xattn_block_size > 1) { + if (desc->has_xattention) { add_stage(xattn_estimate_gemmqk, params); add_stage(xattn_estimate_find_block, params); add_stage(xattn_estimate_post_proc, params); @@ -93,8 +92,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { rt_params->partition_size = get_partition_size(desc->has_xattention); rt_params->num_of_partitions = ceil_div(max_context_len, rt_params->partition_size); rt_params->stage = get_paged_attention_stage(params); - const size_t block_size = get_xattn_block_size(params); - if (block_size > 1) { + if (desc->has_xattention) { update_xattn_rt_params(instance); } @@ -121,7 +119,9 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { res_event = {execute_stage(res_event, instance, kv_cache_update)}; if (rt_params->stage == PagedAttentionStage::PREFILL || rt_params->stage == PagedAttentionStage::MIXED) { - if (has_stage(xattn_estimate_gemmqk)) { + const float xattn_thresh = get_xattn_thresh(params); + const bool validate = xattn_thresh < 1.0; + if (has_stage(xattn_estimate_gemmqk) && validate) { // bypass xattn stages if threshold is larger than 1.0. // cldnn::stream& stream = instance.get_network().get_stream(); // stream.finish(); res_event = {execute_stage(res_event, instance, xattn_estimate_gemmqk)}; @@ -217,8 +217,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto count_kq_max_wg = static_cast(desc->heads_num * N_kq_groups * q_stride_pad); internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg - const size_t block_size = get_xattn_block_size(params); - if (desc->has_xattention && block_size > 1) { + if (desc->has_xattention) { + const size_t block_size = get_xattn_block_size(params); OPENVINO_ASSERT(block_size % STRIDE == 0, "sparse block_size must be devidable by stride."); const uint32_t q_block_pad = ceil_div(q_len, block_size); const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 7ebdb0e1bc1360..435d00f2850232 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -443,16 +443,16 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); - const size_t block_size = get_xattn_block_size(params); - if (desc->has_xattention && block_size > 1) { + if (desc->has_xattention) { args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // sparse_block_mask args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5}); // sparse_block_mask_wg } args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len - if (block_size > 1) { + if (desc->has_xattention) { args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_block_pad args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // k_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // validate } return args; } @@ -536,16 +536,21 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } - std::vector scaler_value = {q_len}; - const size_t block_size = get_xattn_block_size(params); - if (block_size > 1) { - scaler_value.push_back(rtp->xattn_q_block_pad); - scaler_value.push_back(rtp->xattn_k_block_pad); - } - scalars.resize(scaler_value.size()); - for (size_t i = 0; i < scaler_value.size(); ++i) { - scalars[i].t = ScalarDescriptor::Types::INT32; - scalars[i].v.s32 = static_cast(scaler_value[i]); + auto num_scalers = desc->has_xattention ? 4 : 1; + scalars.resize(num_scalers); + scalars[0].t = ScalarDescriptor::Types::INT32; + scalars[0].v.s32 = static_cast(q_len); + if (num_scalers > 1) { + scalars[1].t = ScalarDescriptor::Types::INT32; + scalars[1].v.s32 = static_cast(rtp->xattn_q_block_pad); + + scalars[2].t = ScalarDescriptor::Types::INT32; + scalars[2].v.s32 = static_cast(rtp->xattn_k_block_pad); + + scalars[3].t = ScalarDescriptor::Types::UINT8; + const float xattn_thresh = get_xattn_thresh(params); + const bool validate = xattn_thresh < 1.0; + scalars[3].v.u8 = static_cast(validate); // validate depending on xattn_threshold } }}; } From 3afbdb5ba1b8bb80de71a55fecad75ac9b24fbe2 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Sat, 11 Oct 2025 16:54:02 +0800 Subject: [PATCH 43/96] cpplint error fixes --- .../graph/impls/cm/paged_attention_gen.cpp | 19 ++++++++++--------- .../graph/impls/cm/paged_attention_gen.hpp | 3 ++- .../src/plugin/ops/paged_attention.cpp | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 435d00f2850232..c43f96fe03ffce 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -380,9 +380,9 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() auto get_simple_pitch = [](const layout& layout) { size_t pitch = 1; auto dims_padding = layout.get_padded_dims(); - for(size_t i = dims_padding.size() - 1; i > 0; --i) { + for (size_t i = dims_padding.size() - 1; i > 0; --i) { pitch = dims_padding[i]; - if(pitch > 1) { + if (pitch > 1) { break; } } @@ -841,9 +841,9 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { auto get_simple_pitch = [](const layout& layout) { size_t pitch = 1; auto dims_padding = layout.get_padded_dims(); - for(size_t i = dims_padding.size() - 1; i > 0; --i) { + for (size_t i = dims_padding.size() - 1; i > 0; --i) { pitch = dims_padding[i]; - if(pitch > 1) { + if (pitch > 1) { break; } } @@ -874,9 +874,10 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { size_t max_context_len = get_max_context_len(params); size_t past_len = get_past_len(params, 0); std::cout << "XAttentionEstimateGEMMQK::get_dispatch_data_func: " - << "N_kq_groups: " << N_kq_groups << ", q_stride_pad: " << q_stride_pad << ", scaler_value: " << PartialShape(scaler_value) << ", kv_len: " << kv_len - << ", max_context_len = " << max_context_len << ", past_len = " << past_len << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] - << ", " << wgs.global[2] << "]" + << "N_kq_groups: " << N_kq_groups << ", q_stride_pad: " << q_stride_pad + << ", scaler_value: " << PartialShape(scaler_value) << ", kv_len: " << kv_len + << ", max_context_len = " << max_context_len << ", past_len = " << past_len + << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; dump_block_indices_begins(params); @@ -973,8 +974,8 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { std::cout << "XAttentionEstimateFindBlock::get_dispatch_data_func: " << "xattn_thresh : " << xattn_thresh << " k_block: " << k_block << ", q_block: " << q_block - << " q_stride: " << q_stride << ", q_stride_pad: " << q_stride_pad << ", k_block_pad: " << k_block_pad << ", gws: [" << wgs.global[0] << ", " - << wgs.global[1] << ", " << wgs.global[2] << "]" + << " q_stride: " << q_stride << ", q_stride_pad: " << q_stride_pad<< ", k_block_pad: " << k_block_pad + << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index f17366cb5a9caa..42c6c1283f79ff 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -116,7 +116,8 @@ class PagedAttentionGeneratorSingleTokenFinalization : public PagedAttentionGene //----------------------------------------------------------------------------------------------------------------- class XAttentionEstimateGeneratorBase : public KernelGenerator { public: - explicit XAttentionEstimateGeneratorBase(std::string_view kernel_name, std::string_view stage_suffix = "_cm") : KernelGenerator(kernel_name, stage_suffix) {} + explicit XAttentionEstimateGeneratorBase(std::string_view kernel_name, std::string_view stage_suffix = "_cm") + : KernelGenerator(kernel_name, stage_suffix) {} [[nodiscard]] std::string get_build_options(const RuntimeParams& params) const override { return KernelGenerator::get_build_options(params) + get_pa_build_options(); } diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 6a668540cf3db9..3f25ee703b2b9d 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -96,7 +96,7 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared if (xattention_threshold_input && xattention_threshold_input->get_output_partial_shape(0).is_dynamic()) { // TODO: enable xattention_threshold_input prim.has_xattention = true; - } else if(key_cache_ps[3].get_length() == k_head_size && key_cache_ps[2].get_length() == 256) { + } else if (key_cache_ps[3].get_length() == k_head_size && key_cache_ps[2].get_length() == 256) { prim.has_xattention = true; } From 2c37d0d2ee61d151e95c292c691739aa8dd7d600 Mon Sep 17 00:00:00 2001 From: Chen Peter Date: Sun, 12 Oct 2025 12:39:31 +0800 Subject: [PATCH 44/96] Define ENABLE_PA_CM_PATH for build --- .../tests/unit/test_cases/paged_attention_gpu_test.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index ccd59cf350fc4d..7947a4e7305e5d 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -220,6 +220,8 @@ struct PagedAttentionManager { return get_QKV_memory(value_data, v_head_size, true); } +/* TODO: These CM kernels test should be run only if CM compiler is ready on the system */ +#define ENABLE_PA_CM_PATH 1 // Define it here to make the build passed #if ENABLE_PA_CM_PATH memory::ptr get_key_cache_memory() { auto key_cache_dt = data_types::f16; From cae516aa81e23206fe4cd2eea7c149f17ad991da Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Sun, 12 Oct 2025 15:13:20 +0800 Subject: [PATCH 45/96] Fix worning as error issues on windows with VS2022 Signed-off-by: Zhai, Xuejun --- .../intel_gpu/src/graph/impls/cm/paged_attention.cpp | 6 +++--- .../src/graph/impls/cm/paged_attention_gen.cpp | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index e97f7aeeb3b108..06998ec69396ad 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -64,7 +64,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { const size_t block_size = get_xattn_block_size(params); const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; const size_t q_len = out_shape[0]; - const uint32_t N = kv_len / STRIDE; + const size_t N = kv_len / STRIDE; const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); const auto q_block_pad = ceil_div(q_len, block_size); @@ -212,7 +212,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` const uint32_t N = static_cast(kv_len / STRIDE); const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); - const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); auto count_kq_max_wg = static_cast(desc->heads_num * N_kq_groups * q_stride_pad); internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg @@ -256,4 +256,4 @@ std::unique_ptr PagedAttentionImplementationManager::create_impl } // namespace ov::intel_gpu::cm // BIND_BINARY_BUFFER_WITH_TYPE(cldnn::paged_attention) -BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::cm::PagedAttentionCmImpl) \ No newline at end of file +BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::cm::PagedAttentionCmImpl) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 435d00f2850232..2721e6f924a2f4 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -837,7 +837,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` const uint32_t N = static_cast(kv_len / STRIDE); - const uint32_t K = STRIDE * head_size; + const uint32_t K = static_cast(STRIDE * head_size); auto get_simple_pitch = [](const layout& layout) { size_t pitch = 1; auto dims_padding = layout.get_padded_dims(); @@ -849,8 +849,8 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { } return pitch; }; - const uint32_t query_pitch = get_simple_pitch(querry_layout) * STRIDE; - const uint32_t slice_no = 0, slice = 0; + const size_t query_pitch = get_simple_pitch(querry_layout) * STRIDE; + const size_t slice_no = 0, slice = 0; const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); @@ -950,7 +950,7 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { const uint32_t q_stride = M; const uint32_t k_stride = N; const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); - const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); @@ -1033,7 +1033,7 @@ DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const { const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` const uint32_t N = static_cast(kv_len / STRIDE); const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); - const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); @@ -1057,4 +1057,4 @@ DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const { }}; } -} // namespace ov::intel_gpu::cm \ No newline at end of file +} // namespace ov::intel_gpu::cm From 808a789d7164cc5d4a8f0849192d64818bc6986c Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Fri, 10 Oct 2025 15:31:29 +0800 Subject: [PATCH 46/96] [WA] clean unused kvcache buffer --- .../graph/impls/cm/pa_kv_cache_update_ref.cm | 45 +++++++++++++++++-- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm index d24e55d52cabc3..6706987efe50ca 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm @@ -23,6 +23,7 @@ #endif constexpr uint wg_size = WG_SIZE; +#define REG_K 16 // extern "C" _GENX_MAIN_ void pa_kv_cache_update( extern "C" _GENX_MAIN_ void KERNEL_NAME( @@ -69,7 +70,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const uint token_idx = cm_global_id(2); // token_idx -> subsequence_idx - if (token_idx >= subsequence_begins[batch_size_in_sequences]) return; + // if (token_idx >= subsequence_begins[batch_size_in_sequences]) return; uint subsequence_idx = 0; for (uint i = 0; i < batch_size_in_sequences; i++) { if (token_idx >= subsequence_begins[i] && token_idx < subsequence_begins[i + 1]) { @@ -81,14 +82,50 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( // printf("wg:%d.%d, token_idx: %d, subsequence_idx: %d\n", wg_id, wg_local_id, token_idx, subsequence_idx); const uint subsequence_begin_idx = subsequence_begins[subsequence_idx]; - const uint past_len = past_lens[subsequence_idx]; - const uint current_block_idx = (past_len + token_idx - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE; const uint token_start_pos = (past_len + token_idx - subsequence_begin_idx) % PAGED_ATTENTION_BLOCK_SIZE; - const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx; + if (token_idx >= subsequence_begins[batch_size_in_sequences]) { + #if KV_CACHE_COMPRESSION_PER_TOKEN + #else + // In PTL some V cache are written with NAN or random value due to unknown reason, while PA kernel will leverage lsc cm_load to + // load V cache by 16x16 block with vnni format, it is hard to exclude the unused V cache when NAN is involved in the same 16x16 block. + // Once NAN takes part in dpas, the NAN will propagate and cause result become NAN. + // As a WA, we need to set the unused part(in the same 16 row) of V cache to 0 here. + const uint last_token_idx = (past_len + 1) % PAGED_ATTENTION_BLOCK_SIZE; + const uint last_token_idx_aligned = (last_token_idx + REG_K - 1) / REG_K * REG_K; + + // if (token_idx >= last_token_idx && token_idx < PAGED_ATTENTION_BLOCK_SIZE) { + if (token_idx >= last_token_idx && token_idx < last_token_idx_aligned) { + uint block_k_base_offset = ((past_len + 1) / PAGED_ATTENTION_BLOCK_SIZE) * KV_HEADS_NUM * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; + uint key_out_offset = block_k_base_offset + head_idx * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + token_idx * ADJUSTED_K_HEAD_SIZE; + vector zero_data = 0; + //if(token_idx == last_token_idx_aligned - 1) { + // zero_data[18] = 0xFE00; //test NAN + //} + + // Only reset unused part in the same 16 row for V cache. + // cm_ptr_store((int*)key_cache, key_out_offset * (int)sizeof(half), zero_data.format()); + cm_ptr_store((int*)value_cache, key_out_offset * (int)sizeof(half), zero_data.format()); + + if(0) { + const uint block_idx = key_out_offset / (ADJUSTED_K_HEAD_SIZE * KV_HEADS_NUM * PAGED_ATTENTION_BLOCK_SIZE); + const uint head_idx = (key_out_offset % (ADJUSTED_K_HEAD_SIZE * KV_HEADS_NUM * PAGED_ATTENTION_BLOCK_SIZE)) / (ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE); + const uint block_m = (key_out_offset % (ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE)) / ADJUSTED_K_HEAD_SIZE; + const uint block_n = (key_out_offset % ADJUSTED_K_HEAD_SIZE); + + if(cm_global_id(0)==0 && cm_global_id(1)==0) + printf("token_idx = %d, last_token_idx = %d, subsequence_begins[%d] = %d, past_len = %d, out_token_idx = %d, key_out_offset = %d, reset_block = [%d, %d,%d,%d]\n", + token_idx, last_token_idx, batch_size_in_sequences, subsequence_begins[batch_size_in_sequences], past_len, + key_out_offset/(ADJUSTED_K_HEAD_SIZE * KV_HEADS_NUM), key_out_offset, block_idx, head_idx, block_m, block_n); + } + } + #endif + return; + } + #if KV_CACHE_COMPRESSION_PER_TOKEN // Assume: K_HEAD_SIZE == K_HEAD_SIZE auto quantize_and_store = [&](vector data, uchar* out, uint out_offset, uint token_pos) { From 22f0459435a91e9069ead787f82385e827356557 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Mon, 13 Oct 2025 10:31:58 +0800 Subject: [PATCH 47/96] Fix format issues Signed-off-by: Zhai, Xuejun --- .../src/graph/impls/cm/paged_attention.cpp | 26 ++++++------- .../graph/impls/cm/paged_attention_gen.cpp | 39 +++++++++---------- .../graph/impls/cm/paged_attention_gen.hpp | 3 +- .../impls/ocl_v2/sdpa/paged_attention_opt.cpp | 6 ++- 4 files changed, 36 insertions(+), 38 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 06998ec69396ad..fc3c01ca8af7b0 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -3,26 +3,26 @@ // #include "paged_attention.hpp" -#include "paged_attention_gen.hpp" #include #include #include #include -#include "primitive_cm_base.hpp" -#include "common_utils/kernel_generator_base.hpp" #include "common_utils/jitter.hpp" +#include "common_utils/kernel_generator_base.hpp" #include "intel_gpu/graph/kernel_impl_params.hpp" #include "intel_gpu/primitives/paged_attention.hpp" #include "kv_cache_inst.h" #include "openvino/core/partial_shape.hpp" +#include "paged_attention_gen.hpp" #include "paged_attention_inst.h" +#include "primitive_cm_base.hpp" #include "primitive_inst.h" #define DUMP_XATTN_BLOCK_MASK 0 #if DUMP_XATTN_BLOCK_MASK -#include "openvino/util/file_util.hpp" +# include "openvino/util/file_util.hpp" #endif namespace ov::intel_gpu::cm { @@ -39,7 +39,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { Stage::Ptr xattn_estimate_find_block = make_stage(); Stage::Ptr xattn_estimate_post_proc = make_stage(); - PagedAttentionCmImpl(): PrimitiveImplCM(PagedAttentionImplementationManager::get_type_info_static()) { + PagedAttentionCmImpl() : PrimitiveImplCM(PagedAttentionImplementationManager::get_type_info_static()) { m_rt_params = std::make_unique(); } explicit PagedAttentionCmImpl(const kernel_impl_params& params) : PagedAttentionCmImpl() { @@ -121,7 +121,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { if (rt_params->stage == PagedAttentionStage::PREFILL || rt_params->stage == PagedAttentionStage::MIXED) { const float xattn_thresh = get_xattn_thresh(params); const bool validate = xattn_thresh < 1.0; - if (has_stage(xattn_estimate_gemmqk) && validate) { // bypass xattn stages if threshold is larger than 1.0. + if (has_stage(xattn_estimate_gemmqk) && validate) { // bypass xattn stages if threshold is larger than 1.0. // cldnn::stream& stream = instance.get_network().get_stream(); // stream.finish(); res_event = {execute_stage(res_event, instance, xattn_estimate_gemmqk)}; @@ -141,7 +141,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { std::string format = layout.format.to_string(); std::string tensor; auto dims = layout.get_dims(); - for (size_t r = 0 ; r < layout.get_rank() ; r++) { + for (size_t r = 0; r < layout.get_rank(); r++) { tensor += ("_" + to_string(dims[r])); } // std::string filename = "PA" + std::to_string(pa_id) + "__" + data_type + "_" + tensor + "__" + format + ".bin"; @@ -209,13 +209,13 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto out_shape = params.output_layouts[0].get_shape(); const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; const size_t q_len = out_shape[0]; - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` const uint32_t N = static_cast(kv_len / STRIDE); const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); auto count_kq_max_wg = static_cast(desc->heads_num * N_kq_groups * q_stride_pad); - internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg + internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg if (desc->has_xattention) { const size_t block_size = get_xattn_block_size(params); @@ -225,12 +225,12 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); const uint32_t k_block_pad = k_block_in_group * N_kq_groups; auto count_kq_exp_partial_sum = static_cast(desc->heads_num * q_stride_pad * k_block_pad); - internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f32); // 3: kq_exp_partial_sum + internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f32); // 3: kq_exp_partial_sum auto count_elements_mask = static_cast(desc->heads_num * q_block_pad * k_block_pad); - internal_buffers.emplace_back(count_elements_mask, ov::element::boolean); // 4: sparse_block_mask + internal_buffers.emplace_back(count_elements_mask, ov::element::boolean); // 4: sparse_block_mask - const uint32_t MERGED_Q_NUM = 2; // TODO + const uint32_t MERGED_Q_NUM = 2; // TODO const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); auto count_elements_mask_merged = static_cast(desc->heads_num * q_block_pad_merged * k_block_pad); internal_buffers.emplace_back(count_elements_mask_merged, ov::element::boolean); // 5: sparse_block_mask_wg @@ -254,6 +254,6 @@ std::unique_ptr PagedAttentionImplementationManager::create_impl } } -} // namespace ov::intel_gpu::cm +} // namespace ov::intel_gpu::cm // BIND_BINARY_BUFFER_WITH_TYPE(cldnn::paged_attention) BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::cm::PagedAttentionCmImpl) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index cbc047921e4d56..967b87a840bddf 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -156,7 +156,8 @@ int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAtt if (stage == PagedAttentionStage::PREFILL) { const auto desc = impl_param.typed_desc(); int64_t pa_block_size = paged_attention::block_size; - if (desc->has_xattention) pa_block_size = paged_attention::block_size_xattn; + if (desc->has_xattention) + pa_block_size = paged_attention::block_size_xattn; if (static_cast(pa_block_size) == target_seq_len_block_size) { const auto& block_indices_ps = impl_param.get_input_layout(PagedAttentionInputIdx::BLOCK_INDICES).get_partial_shape(); @@ -213,7 +214,7 @@ size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx) { const float get_xattn_thresh(const kernel_impl_params& params, const size_t seq_idx) { const auto& input_mem = params.memory_deps; const auto threshold_mem = input_mem.at(PagedAttentionInputIdx::XATTENTION_THRESHOLD); - mem_lock lock(threshold_mem, *params.strm); // converted + mem_lock lock(threshold_mem, *params.strm); // converted const auto thresh = static_cast(lock[seq_idx]); return thresh; } @@ -448,11 +449,11 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5}); // sparse_block_mask_wg } - args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len if (desc->has_xattention) { - args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_block_pad - args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // k_block_pad - args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // validate + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // k_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // validate } return args; } @@ -715,12 +716,10 @@ DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_da }}; } - //----------------------------------------------------------------------------------------------------------------- // Helpers of XAttention //----------------------------------------------------------------------------------------------------------------- - //----------------------------------------------------------------------------------------------------------------- // Base generator of XAttention //----------------------------------------------------------------------------------------------------------------- @@ -835,7 +834,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { auto out_shape = params.output_layouts[0].get_shape(); const size_t q_len = out_shape[0]; - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` const uint32_t N = static_cast(kv_len / STRIDE); const uint32_t K = static_cast(STRIDE * head_size); auto get_simple_pitch = [](const layout& layout) { @@ -874,10 +873,9 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { size_t max_context_len = get_max_context_len(params); size_t past_len = get_past_len(params, 0); std::cout << "XAttentionEstimateGEMMQK::get_dispatch_data_func: " - << "N_kq_groups: " << N_kq_groups << ", q_stride_pad: " << q_stride_pad - << ", scaler_value: " << PartialShape(scaler_value) << ", kv_len: " << kv_len - << ", max_context_len = " << max_context_len << ", past_len = " << past_len - << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" + << "N_kq_groups: " << N_kq_groups << ", q_stride_pad: " << q_stride_pad << ", scaler_value: " << PartialShape(scaler_value) + << ", kv_len: " << kv_len << ", max_context_len = " << max_context_len << ", past_len = " << past_len << ", gws: [" << wgs.global[0] + << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; dump_block_indices_begins(params); @@ -885,7 +883,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { } for (size_t i = 0; i < scaler_value.size(); ++i) { - if (i == 4 || i == 5) { + if (i == 4 || i == 5) { scalars[i].t = ScalarDescriptor::Types::INT32; scalars[i].v.s32 = static_cast(scaler_value[i]); } else { @@ -946,7 +944,7 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { auto out_shape = params.output_layouts[0].get_shape(); const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; const size_t q_len = out_shape[0]; - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` const uint32_t N = static_cast(kv_len / STRIDE); const uint32_t q_stride = M; const uint32_t k_stride = N; @@ -972,10 +970,9 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { if (DEBUG_ENABLED) { // Debug std::cout << "XAttentionEstimateFindBlock::get_dispatch_data_func: " - << "xattn_thresh : " << xattn_thresh - << " k_block: " << k_block << ", q_block: " << q_block - << " q_stride: " << q_stride << ", q_stride_pad: " << q_stride_pad<< ", k_block_pad: " << k_block_pad - << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" + << "xattn_thresh : " << xattn_thresh << " k_block: " << k_block << ", q_block: " << q_block << " q_stride: " << q_stride + << ", q_stride_pad: " << q_stride_pad << ", k_block_pad: " << k_block_pad << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " + << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } @@ -1031,7 +1028,7 @@ DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const { auto out_shape = params.output_layouts[0].get_shape(); const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; const size_t q_len = out_shape[0]; - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` const uint32_t N = static_cast(kv_len / STRIDE); const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); @@ -1041,7 +1038,7 @@ DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const { const uint32_t k_block_pad = k_block_in_group * N_kq_groups; const uint32_t q_block_pad = ceil_div(q_len, block_size); - const uint32_t MERGED_Q_NUM = 2; // TODO + const uint32_t MERGED_Q_NUM = 2; // TODO const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); wgs.global = {q_block_pad_merged, heads_num, 1}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 42c6c1283f79ff..337c974dbe7156 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -31,7 +31,7 @@ constexpr auto get_pa_build_options() { } // BLOCK_SIZE can be 16/256 for legacy and xattn cases respectively -#define PA_KV_CACHE_BLOCK_SIZE 16 +#define PA_KV_CACHE_BLOCK_SIZE 16 #define PA_KV_CACHE_BLOCK_SIZE_XATTN 256 constexpr uint32_t BLOCK_SG_M = 64; @@ -54,7 +54,6 @@ struct PagedAttentionRuntimeParams : public ImplRuntimeParams { size_t xattn_k_block_pad; }; - //----------------------------------------------------------------------------------------------------------------- // Helpers of XAttention //----------------------------------------------------------------------------------------------------------------- diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp index 62ebcbece1d81c..464087f9557f47 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp @@ -196,7 +196,8 @@ static int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const P if (stage == PagedAttentionStage::PREFILL) { const auto desc = impl_param.typed_desc(); int64_t pa_block_size = paged_attention::block_size; - if (desc->has_xattention) pa_block_size = paged_attention::block_size_xattn; + if (desc->has_xattention) + pa_block_size = paged_attention::block_size_xattn; if (static_cast(pa_block_size) == target_seq_len_block_size) { const auto& block_indices_ps = impl_param.get_input_layout(PagedAttentionInputIdx::BLOCK_INDICES).get_partial_shape(); @@ -1579,7 +1580,8 @@ class PagedAttentionOptImpl : public SDPAImplBase { size_t micro_sdpa_index = 0; size_t subsequence_offsets_acc = 0; int pa_block_size = paged_attention::block_size; - if (desc->has_xattention) pa_block_size = paged_attention::block_size_xattn; + if (desc->has_xattention) + pa_block_size = paged_attention::block_size_xattn; for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) { const auto past_len = past_lens_mem_lock[i]; const auto seq_start = subsequence_begins_mem_lock[i]; From 780f55a590d598e709c2df1a1ea569d499355b80 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 13 Oct 2025 10:55:26 +0800 Subject: [PATCH 48/96] disable XAttention for legacy platforms (XAttention kernels are implemented for Xe2/Xe3 with CM) --- .../src/plugin/ops/paged_attention.cpp | 9 +- .../src/plugin/transformations_pipeline.cpp | 126 +++++++++++------- 2 files changed, 84 insertions(+), 51 deletions(-) diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 3f25ee703b2b9d..0c8fdf5432884b 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -91,12 +91,9 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared prim.has_rotated_blocks = true; } - const size_t xattention_threshold_idx = cldnn::paged_attention::PagedAttentionInputIdx::XATTENTION_THRESHOLD; - auto xattention_threshold_input = ov::as_type_ptr(op->get_input_node_shared_ptr(xattention_threshold_idx)); - if (xattention_threshold_input && xattention_threshold_input->get_output_partial_shape(0).is_dynamic()) { - // TODO: enable xattention_threshold_input - prim.has_xattention = true; - } else if (key_cache_ps[3].get_length() == k_head_size && key_cache_ps[2].get_length() == 256) { + // We may fallback to dense attention mode if xattn is not supported by either GPU archieture or compiler. + // So we check key cache shape, instead of checking op inputs to determine if xatnn is enabled. + if (key_cache_ps[3].get_length() == k_head_size && key_cache_ps[2].get_length() == 256) { prim.has_xattention = true; } diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 0e66bf84be4df5..f0788cd3abb075 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -508,54 +508,90 @@ void TransformationsPipeline::apply(std::shared_ptr func) { // To handle this case, "KeepConstPrecision" is executed again. manager.register_pass(supported_woq_types, !device_info.supports_immad); - bool use_xattention = false; - const auto& parameters = func->get_parameters(); - for (const auto& param : parameters) { - if (param->get_friendly_name() == "xattention_block_size") { - use_xattention = true; + { + // Disable XAttention if GPU Xe2/Xe3 architectures is unavaiable or IGC incompatiable. + auto check_xattn_gpu_compatibility = [&](void) -> bool { + auto& engine = m_context->get_engine(); + const auto& info = engine.get_device_info(); + if (info.arch != cldnn::gpu_arch::xe2 && info.arch != cldnn::gpu_arch::xe3) { // CM optimized for systolic-array architectures + return false; + } + +#ifdef GPU_DEBUG_CONFIG + if (!config.get_use_cm()) { + OPENVINO_WARN("You may miss SDPAToVLSDPA optimization for QWenVL model," + "as CM for usage is disabled. Enable it by setting environment variable OV_GPU_USE_CM=ON."); + return false; + } +#endif + + if (!check_cm_jit_support(engine, config)) { + OPENVINO_WARN("You may miss SDPAToVLSDPA optimization for QWenVL model," + "as current IGC version is not compatible to the CM kernel used. Enable it by update IGC." + "Please also make sure clangFEWrapper for CM is present by checking environment varibles like " + "CM_FE_DIR or LD_LIBRARY_PATH if you are using Linux."); + return false; + } + + return true; + }; + + // Determine if XAttention is enabled by user (via GENAI) by checking if model parameters contains + // xattention configurations, which are added in SDPAToPagedAttention pass. + bool use_xattention = false; + const auto& parameters = func->get_parameters(); + for (const auto& param : parameters) { + if (param->get_friendly_name() == "xattention_block_size") { + use_xattention = true; + break; + } } - } - ov::pass::ConvertPagedAttnInputs::KVCacheConfig kv_cache_config; - kv_cache_config.keyCachePrecision = config.get_kv_cache_precision(); - kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); - kv_cache_config.inferencePrecision = infer_precision; - if (use_xattention) { - kv_cache_config.keyCacheBlockSize = 256; - kv_cache_config.keyCacheDimOrder = {0, 1, 2, 3}; - } else { - kv_cache_config.keyCacheBlockSize = 16; - kv_cache_config.keyCacheDimOrder = {0, 1, 3, 2}; - } - kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); - kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; - if (use_xattention) { - kv_cache_config.valueCacheBlockSize = 256; - kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; - } else { - kv_cache_config.valueCacheBlockSize = 16; - kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; - } - kv_cache_config.valueCacheQuantBychannel = false; - kv_cache_config.valueCacheGroupSize = 0; - - manager.register_pass(kv_cache_config, - [&infer_precision](const ov::element::Type& precision, - const bool bychannel, - const size_t group_num, - int64_t& head_size, - int64_t& block_size) { - if (bychannel) { - // TODO: need to handle group size != block size case - if (precision == ov::element::i8 || precision == ov::element::u8) { - block_size += infer_precision.size() * 2; - } - } else { - if (precision == ov::element::i8 || precision == ov::element::u8) { - head_size += infer_precision.size() * 2 * group_num; + // Fallback to dense attention if xattn is not supported by either GPU archieture or compiler. + if (use_xattention) + use_xattention = check_xattn_gpu_compatibility(); + + ov::pass::ConvertPagedAttnInputs::KVCacheConfig kv_cache_config; + kv_cache_config.keyCachePrecision = config.get_kv_cache_precision(); + kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); + kv_cache_config.inferencePrecision = infer_precision; + if (use_xattention) { + kv_cache_config.keyCacheBlockSize = 256; + kv_cache_config.keyCacheDimOrder = {0, 1, 2, 3}; // default dim order of [num_blocks, num_kv_heads, block_size, head_size] + } else { + kv_cache_config.keyCacheBlockSize = 16; + kv_cache_config.keyCacheDimOrder = {0, 1, 3, 2}; + } + kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); + kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; + if (use_xattention) { + kv_cache_config.valueCacheBlockSize = 256; + kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; + } else { + kv_cache_config.valueCacheBlockSize = 16; + kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; + } + kv_cache_config.valueCacheQuantBychannel = false; + kv_cache_config.valueCacheGroupSize = 0; + + manager.register_pass(kv_cache_config, + [&infer_precision](const ov::element::Type& precision, + const bool bychannel, + const size_t group_num, + int64_t& head_size, + int64_t& block_size) { + if (bychannel) { + // TODO: need to handle group size != block size case + if (precision == ov::element::i8 || precision == ov::element::u8) { + block_size += infer_precision.size() * 2; + } + } else { + if (precision == ov::element::i8 || precision == ov::element::u8) { + head_size += infer_precision.size() * 2 * group_num; + } } - } - }); + }); + } pass_config->set_callback([&](const std::shared_ptr node){ if (!config.get_enable_sdpa_optimization()) From d21c4f6b1e8f48e3cbb6ec6192a958f5fdc59817 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Mon, 13 Oct 2025 11:09:57 +0800 Subject: [PATCH 49/96] reset left V cache block rather than 16 rows --- .../intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm | 6 +++--- src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm | 2 +- .../intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm index 6706987efe50ca..b85fea9dc1397d 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm @@ -95,10 +95,10 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( // Once NAN takes part in dpas, the NAN will propagate and cause result become NAN. // As a WA, we need to set the unused part(in the same 16 row) of V cache to 0 here. const uint last_token_idx = (past_len + 1) % PAGED_ATTENTION_BLOCK_SIZE; - const uint last_token_idx_aligned = (last_token_idx + REG_K - 1) / REG_K * REG_K; + // const uint last_token_idx_aligned = (last_token_idx + REG_K - 1) / REG_K * REG_K; - // if (token_idx >= last_token_idx && token_idx < PAGED_ATTENTION_BLOCK_SIZE) { - if (token_idx >= last_token_idx && token_idx < last_token_idx_aligned) { + if (token_idx >= last_token_idx && token_idx < PAGED_ATTENTION_BLOCK_SIZE) { + // if (token_idx >= last_token_idx && token_idx < last_token_idx_aligned) { uint block_k_base_offset = ((past_len + 1) / PAGED_ATTENTION_BLOCK_SIZE) * KV_HEADS_NUM * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; uint key_out_offset = block_k_base_offset + head_idx * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + token_idx * ADJUSTED_K_HEAD_SIZE; vector zero_data = 0; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index 0537797766bb66..f25f571025b3f0 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -143,7 +143,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #if USE_LSC_BLOCK_2D_DESC #if KV_CACHE_COMPRESSION - // Transpose only support dword and qwork + // Transpose only support dword and qword lsc::block_2d_desc b2dK(reinterpret_cast(key + kv_base_offset), KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); #else lsc::block_2d_desc b2dK(reinterpret_cast(key + kv_base_offset), KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(half) - 1, kv_pitch - 1, 0, 0); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 20c2aee0622b32..b5b4bcd82a1590 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -332,9 +332,9 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() // const size_t kv_len = get_max_context_len(params); const size_t kv_len = get_input_kv_len(params); const size_t kv_heads_num = desc->kv_heads_num; - const size_t wg_count = (kv_len + WG_SIZE - 1) / WG_SIZE; + const size_t wg_count = (kv_len + PA_KV_CACHE_BLOCK_SIZE - 1) / PA_KV_CACHE_BLOCK_SIZE; - wgs.global = {1, kv_heads_num, wg_count * WG_SIZE}; + wgs.global = {1, kv_heads_num, wg_count * PA_KV_CACHE_BLOCK_SIZE}; wgs.local = {1, 1, WG_SIZE}; auto& scalars = kd.params.scalars; From 6b9b4c25827de61cea8c0b61bdcecb4948d830dc Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Mon, 13 Oct 2025 11:18:13 +0800 Subject: [PATCH 50/96] Remove debug code --- .../graph/impls/cm/pa_kv_cache_update_ref.cm | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm index b85fea9dc1397d..7016dcddcc6bb9 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm @@ -95,32 +95,15 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( // Once NAN takes part in dpas, the NAN will propagate and cause result become NAN. // As a WA, we need to set the unused part(in the same 16 row) of V cache to 0 here. const uint last_token_idx = (past_len + 1) % PAGED_ATTENTION_BLOCK_SIZE; - // const uint last_token_idx_aligned = (last_token_idx + REG_K - 1) / REG_K * REG_K; if (token_idx >= last_token_idx && token_idx < PAGED_ATTENTION_BLOCK_SIZE) { - // if (token_idx >= last_token_idx && token_idx < last_token_idx_aligned) { uint block_k_base_offset = ((past_len + 1) / PAGED_ATTENTION_BLOCK_SIZE) * KV_HEADS_NUM * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; uint key_out_offset = block_k_base_offset + head_idx * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + token_idx * ADJUSTED_K_HEAD_SIZE; vector zero_data = 0; - //if(token_idx == last_token_idx_aligned - 1) { - // zero_data[18] = 0xFE00; //test NAN - //} // Only reset unused part in the same 16 row for V cache. // cm_ptr_store((int*)key_cache, key_out_offset * (int)sizeof(half), zero_data.format()); cm_ptr_store((int*)value_cache, key_out_offset * (int)sizeof(half), zero_data.format()); - - if(0) { - const uint block_idx = key_out_offset / (ADJUSTED_K_HEAD_SIZE * KV_HEADS_NUM * PAGED_ATTENTION_BLOCK_SIZE); - const uint head_idx = (key_out_offset % (ADJUSTED_K_HEAD_SIZE * KV_HEADS_NUM * PAGED_ATTENTION_BLOCK_SIZE)) / (ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE); - const uint block_m = (key_out_offset % (ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE)) / ADJUSTED_K_HEAD_SIZE; - const uint block_n = (key_out_offset % ADJUSTED_K_HEAD_SIZE); - - if(cm_global_id(0)==0 && cm_global_id(1)==0) - printf("token_idx = %d, last_token_idx = %d, subsequence_begins[%d] = %d, past_len = %d, out_token_idx = %d, key_out_offset = %d, reset_block = [%d, %d,%d,%d]\n", - token_idx, last_token_idx, batch_size_in_sequences, subsequence_begins[batch_size_in_sequences], past_len, - key_out_offset/(ADJUSTED_K_HEAD_SIZE * KV_HEADS_NUM), key_out_offset, block_idx, head_idx, block_m, block_n); - } } #endif return; From eb9765e3961566388f6f11bfcf02d6d59eb8b1c2 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 13 Oct 2025 11:30:51 +0800 Subject: [PATCH 51/96] revert code change to ocl_v2 --- .../src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp index 464087f9557f47..200c9a31e00398 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp @@ -195,10 +195,7 @@ static int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const P int64_t aligned_seq_len = 0; if (stage == PagedAttentionStage::PREFILL) { const auto desc = impl_param.typed_desc(); - int64_t pa_block_size = paged_attention::block_size; - if (desc->has_xattention) - pa_block_size = paged_attention::block_size_xattn; - if (static_cast(pa_block_size) == target_seq_len_block_size) { + if (static_cast(paged_attention::block_size) == target_seq_len_block_size) { const auto& block_indices_ps = impl_param.get_input_layout(PagedAttentionInputIdx::BLOCK_INDICES).get_partial_shape(); aligned_seq_len = block_indices_ps[0].get_length() * target_seq_len_block_size; @@ -1579,9 +1576,7 @@ class PagedAttentionOptImpl : public SDPAImplBase { size_t index = 0; size_t micro_sdpa_index = 0; size_t subsequence_offsets_acc = 0; - int pa_block_size = paged_attention::block_size; - if (desc->has_xattention) - pa_block_size = paged_attention::block_size_xattn; + const auto pa_block_size = static_cast(paged_attention::block_size); for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) { const auto past_len = past_lens_mem_lock[i]; const auto seq_start = subsequence_begins_mem_lock[i]; From 1418daa21f0fcc8bdcfdd1a6d76c7ea1d6052b7b Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 13 Oct 2025 11:31:09 +0800 Subject: [PATCH 52/96] cleanup debug code --- .../src/graph/impls/cm/paged_attention.cpp | 31 ------------------- .../graph/impls/cm/paged_attention_gen.cpp | 23 -------------- .../src/graph/include/paged_attention_inst.h | 4 --- 3 files changed, 58 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index fc3c01ca8af7b0..8b00651285825c 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -20,11 +20,6 @@ #include "primitive_cm_base.hpp" #include "primitive_inst.h" -#define DUMP_XATTN_BLOCK_MASK 0 -#if DUMP_XATTN_BLOCK_MASK -# include "openvino/util/file_util.hpp" -#endif - namespace ov::intel_gpu::cm { class PagedAttentionCmImpl : public PrimitiveImplCM { @@ -122,34 +117,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { const float xattn_thresh = get_xattn_thresh(params); const bool validate = xattn_thresh < 1.0; if (has_stage(xattn_estimate_gemmqk) && validate) { // bypass xattn stages if threshold is larger than 1.0. - // cldnn::stream& stream = instance.get_network().get_stream(); - // stream.finish(); res_event = {execute_stage(res_event, instance, xattn_estimate_gemmqk)}; - // stream.finish(); - // std::cout << "finish xattn_estimate_gemmqk!\n"; res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; -#if DUMP_XATTN_BLOCK_MASK - { - cldnn::stream& stream = instance.get_network().get_stream(); - stream.finish(); - static uint32_t pa_id = 0; - std::cout << "finish xattn_estimate_find_block!\n"; - auto output_mem = instance.get_intermediates_memories()[4]; - mem_lock lock(output_mem, stream); - auto& layout = output_mem->get_layout(); - std::string data_type = ov::element::Type(layout.data_type).get_type_name(); - std::string format = layout.format.to_string(); - std::string tensor; - auto dims = layout.get_dims(); - for (size_t r = 0; r < layout.get_rank(); r++) { - tensor += ("_" + to_string(dims[r])); - } - // std::string filename = "PA" + std::to_string(pa_id) + "__" + data_type + "_" + tensor + "__" + format + ".bin"; - std::string filename = "PA" + std::to_string(pa_id) + ".bin"; - ov::util::save_binary(filename, lock.data(), output_mem->size()); - pa_id++; - } -#endif res_event = {execute_stage(res_event, instance, xattn_estimate_post_proc)}; } res_event = {execute_stage(res_event, instance, pa_multi_token)}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 967b87a840bddf..397b162a99a2f3 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -219,26 +219,6 @@ const float get_xattn_thresh(const kernel_impl_params& params, const size_t seq_ return thresh; } -inline void dump_block_indices_begins(const kernel_impl_params& params) { - const auto& input_mem = params.memory_deps; - const auto mem = input_mem.at(PagedAttentionInputIdx::BLOCK_INDICES_BEGINS); - mem_lock mem_lock(mem, *params.strm); - std::cout << "============ dump BLOCK_INDICES_BEGINS ["; - for (size_t i = 0; i < mem->count(); i++) - std::cout << mem_lock[i] << ", "; - std::cout << "]" << std::endl; -} - -inline void dump_block_indices(const kernel_impl_params& params) { - const auto& input_mem = params.memory_deps; - const auto mem = input_mem.at(PagedAttentionInputIdx::BLOCK_INDICES); - mem_lock mem_lock(mem, *params.strm); - std::cout << "============ dump BLOCK_INDICES ["; - for (size_t i = 0; i < mem->count(); i++) - std::cout << mem_lock[i] << ", "; - std::cout << "]" << std::endl; -} - PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) { const auto& query_shape = impl_param.get_input_layout(PagedAttentionInputIdx::QUERY).get_partial_shape(); const auto& past_lens_shape = impl_param.get_input_layout(PagedAttentionInputIdx::PAST_LENS).get_partial_shape(); @@ -877,9 +857,6 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { << ", kv_len: " << kv_len << ", max_context_len = " << max_context_len << ", past_len = " << past_len << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; - - dump_block_indices_begins(params); - dump_block_indices(params); } for (size_t i = 0; i < scaler_value.size(); ++i) { diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h index 2eecc48cb71015..dbe8a27f0a79c6 100644 --- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h @@ -25,10 +25,6 @@ struct typed_program_node : public typed_program_node_basehas_score_aggregation) input_ports.insert(PagedAttentionInputIdx::SCORE_AGGREGATION); From 21c3193151152a05944a92fb669850cdb745a282 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Mon, 13 Oct 2025 14:36:16 +0800 Subject: [PATCH 53/96] Limit head_num/kv_head_num not excceed 8 --- .../src/graph/impls/cm/paged_attention.hpp | 7 + .../graph/impls/cm/paged_attention_gen.cpp | 153 ------------------ 2 files changed, 7 insertions(+), 153 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp index 85cecbbfe39dc7..1f3ff893011884 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp @@ -37,6 +37,13 @@ struct PagedAttentionImplementationManager : public ImplementationManager { return false; } + // TODO: Remove this limitation when PA CM kernel supports more "heads_num / kv_heads_num" cases. + // PA 2nd token CM kernel only supports case of "heads_num / kv_heads_num <= 8" + if (desc->heads_num / desc->kv_heads_num > 8) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false because heads_num / kv_heads_num > 8. " << std::endl; + return false; + } + auto& engine = node.get_program().get_engine(); const auto& config = node.get_program().get_config(); const auto& info = engine.get_device_info(); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 397b162a99a2f3..84c58e595dcff6 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -72,11 +72,6 @@ inline size_t get_input_kv_len(const RuntimeParams& params) { return kv_len; } -inline size_t get_aligned_kv_len(const size_t kv_len) { - // TODO: how to change PA_KV_CACHE_BLOCK_SIZE here - return (kv_len + PA_KV_CACHE_BLOCK_SIZE - 1) / PA_KV_CACHE_BLOCK_SIZE * PA_KV_CACHE_BLOCK_SIZE; -} - inline bool get_kv_compressed(const RuntimeParams& params) { auto key_cache_layout = params.input_layouts[PagedAttentionInputIdx::KEY_CACHE]; if (data_type_traits::is_i8_u8(key_cache_layout.data_type)) { @@ -86,98 +81,7 @@ inline bool get_kv_compressed(const RuntimeParams& params) { } } -int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, int64_t target_seq_len_block_size = 16) { - // Since at prefill stage Q, K, V inputs may contain multiple sequences with arbitrary - // target sequence lengths each (shape is [sequences_num * target_seq_len, num_heads * head_size]), - // to apply blocking to the first dimension (target_seq_len of each sequence), we need to calculate aligned total - // target sequence length for proper kernel dispatching - // For instance, if input contains two sequences with 35 and 28 sequence lengths each, - // the Q, K, V inputs at prefill stage will have shapes [35 + 28, num_heads * head_size]; considering kernel's - // target_seq_len_block_size equals 16, we need to launch kernel instances for the following ranges: - // [0, 15], [16, 31], [32, 34], [35, 50], [51, 62], so aligned target_seq_len_block_size should be 5 * 16 = 80, - // and 5 kernels instances should be launched (for each range, some of them containing leftovers) - // - // In general, to obtain length for each sequence, we have to parse subsequence_begins input, - // which contains begin and end indexes for each sequence (for above example it will contain three values: {0, 35, 63}) - // However, as long as kernel's target_seq_len_block_size matches with vLLM's block_size, - // we can reuse block_indices_shape[0] size to determine total aligned sequences length size, avoiding - // memory access at runtime, because vLLM internally uses similar logic to configure blocks for KV cache - - auto calculate_aligned_seq_len = [&]() { - const auto& input_mem = impl_param.memory_deps; - const auto subsequence_begins_mem = input_mem.at(PagedAttentionInputIdx::SUBSEQUENCE_BEGINS); - mem_lock subsequence_begins_mem_lock(subsequence_begins_mem, *impl_param.strm); - - auto aligned_seq_len = 0; - if (stage == PagedAttentionStage::MIXED) { - const auto past_lens_mem = input_mem.at(PagedAttentionInputIdx::PAST_LENS); - mem_lock past_lens_mem_lock(past_lens_mem, *impl_param.strm); - - for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) { - auto past_len = past_lens_mem_lock[i]; - auto seq_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i]; - - // Since in MIXED execution mode the present KV-cache can be appended to the past KV-cache at any offset inside block, - // to ensure proper alignment and update_kv_cache kernel scheduling, we need to account for the number of unaligned tokens - // in the first block - // For example, if we need to store values in the following slots: - // - // block0: |O|O|O|O|O|O|O|O|O|O|O|O|U|U|U|U| - // block1: |U|U|U|U|U|U|U|U|U|U|U|U|U|U|U|U| - // block2: |U|U|U|U|U|U|E|E|E|E|E|E|E|E|E|E| - // Where O - occupied slots, U - currently beeing updated slots, E - empty slots - // - // We need to schedule 3 update_kv_cache operations: - // - For ranges of block0: [12-15] - // - For ranges of block1: [0-15] - // - For ranges of block2: [0-5] - // - // Therefore, consider an additional increment of aligned_seq_len to properly process all the blocks - - auto occupied_slots_num = past_len % target_seq_len_block_size; - if (past_len != 0 && seq_length + occupied_slots_num > target_seq_len_block_size) { - aligned_seq_len += target_seq_len_block_size; - seq_length -= target_seq_len_block_size - occupied_slots_num; - } - - aligned_seq_len += align_to(seq_length, target_seq_len_block_size); - } - } else { - for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) { - auto prompt_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i]; - aligned_seq_len += align_to(prompt_length, target_seq_len_block_size); - } - } - - return aligned_seq_len; - }; - - int64_t aligned_seq_len = 0; - if (stage == PagedAttentionStage::PREFILL) { - const auto desc = impl_param.typed_desc(); - int64_t pa_block_size = paged_attention::block_size; - if (desc->has_xattention) - pa_block_size = paged_attention::block_size_xattn; - if (static_cast(pa_block_size) == target_seq_len_block_size) { - const auto& block_indices_ps = impl_param.get_input_layout(PagedAttentionInputIdx::BLOCK_INDICES).get_partial_shape(); - - aligned_seq_len = block_indices_ps[0].get_length() * target_seq_len_block_size; - } else { - aligned_seq_len = calculate_aligned_seq_len(); - } - } else { - aligned_seq_len = calculate_aligned_seq_len(); - } - - return aligned_seq_len; -} - size_t get_partition_size(const bool has_xattention) { - // size_t k_partition_blok_num = (kv_len + 8191) / 8192; - // if (k_partition_blok_num < 1) - // k_partition_blok_num = 1; - // const size_t k_partition_blok_num = 16; - // return k_partition_blok_num * PA_KV_CACHE_BLOCK_SIZE; // 128 if (!has_xattention && PA_KV_CACHE_BLOCK_SIZE < 128) { return 128; } else { @@ -249,9 +153,7 @@ PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_par JitConstants PagedAttentionGeneratorBase::get_jit_constants(const kernel_impl_params& params) const { auto jit = KernelGenerator::get_jit_constants(params); jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); - // std::cout << "PagedAttentionGeneratorBase::get_jit_constants: " << get_entry_point(params) << std::endl; - // auto desc = params.typed_desc(); auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; jit.make("XE_ARCH", xe_arch); @@ -317,9 +219,7 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() assert(!params.is_dynamic()); auto& wgs = kd.params.workGroups; const auto desc = params.typed_desc(); - // auto rtp = static_cast(rt_params); - // const size_t kv_len = get_max_context_len(params); const size_t kv_len = get_input_kv_len(params); const size_t kv_heads_num = desc->kv_heads_num; const size_t wg_count = (kv_len + WG_SIZE - 1) / WG_SIZE; @@ -333,31 +233,6 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() auto key_layout = params.input_layouts[PagedAttentionInputIdx::KEY]; auto value_layout = params.input_layouts[PagedAttentionInputIdx::VALUE]; - if (0) { // Debug - std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " - << "key_layout: " << key_layout.to_string() << ", value_layout: " << value_layout.to_string() << std::endl; - std::cout << "\tkey_dims = ["; - for (auto& it : key_layout.get_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tkey_pads = ["; - for (auto& it : key_layout.get_padded_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tvalue_dims = ["; - for (auto& it : value_layout.get_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tvalue_pads = ["; - for (auto& it : value_layout.get_padded_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - } - auto get_simple_pitch = [](const layout& layout) { size_t pitch = 1; auto dims_padding = layout.get_padded_dims(); @@ -393,7 +268,6 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() << ", value_offset: " << value_offset << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } - // TODO: support multiple sequences size_t batch_size_in_sequences = 1; std::vector scaler_value = {key_pitch, key_offset, value_pitch, value_offset, batch_size_in_sequences}; @@ -465,10 +339,6 @@ JitConstants PagedAttentionGeneratorMultiToken::get_jit_constants(const kernel_i } else { jit.make("CMPA_KVCACHE_U8", 0); } - // for (auto& it : jit) { - // std::cout << "\tjit[" << it.name << "] = " << it.value << std::endl; - // } - // std::cout << std::endl; return jit; } @@ -480,23 +350,8 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con auto rtp = static_cast(rt_params); // assert(rt_params != nullptr); const size_t heads_num = desc->heads_num; - auto query_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; - if (0 && DEBUG_ENABLED) { // Debug - std::cout << "PagedAttentionGeneratorMultiToken::get_dispatch_data_func: query_layout: " << query_layout.to_string() << std::endl; - std::cout << "\tquery_dims = ["; - for (auto& it : query_layout.get_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tquery_pads = ["; - for (auto& it : query_layout.get_padded_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - } - auto out_shape = params.output_layouts[0].get_shape(); const size_t batch = out_shape.size() < 4 ? 1 : out_shape[0]; const size_t q_len = out_shape[0]; @@ -516,7 +371,6 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } - auto num_scalers = desc->has_xattention ? 4 : 1; scalars.resize(num_scalers); scalars[0].t = ScalarDescriptor::Types::INT32; @@ -625,7 +479,6 @@ DispatchDataFunc PagedAttentionGeneratorSingleToken::get_dispatch_data_func() co << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } - for (size_t i = 0; i < scaler_value.size(); ++i) { scalars[i].t = ScalarDescriptor::Types::INT32; scalars[i].v.s32 = static_cast(scaler_value[i]); @@ -688,7 +541,6 @@ DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_da << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } - for (size_t i = 0; i < scaler_value.size(); ++i) { scalars[i].t = ScalarDescriptor::Types::INT32; scalars[i].v.s32 = static_cast(scaler_value[i]); @@ -737,11 +589,6 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp } jit.make("SOFTMAX_TYPE", "float"); - // for (auto& it : jit) { - // std::cout << "\tjit[" << it.name << "] = " << it.value << std::endl; - // } - // std::cout << std::endl; - return jit; } From 8a7a38083393f2feeb2a30155a96b87af124cd91 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 13 Oct 2025 14:48:36 +0800 Subject: [PATCH 54/96] streamline block_size head_size in both cases of fp16 and u8/i8 kvcache --- .../intel_gpu/src/graph/paged_attention.cpp | 14 ++++++-------- .../intel_gpu/src/plugin/ops/paged_attention.cpp | 15 ++++++++------- .../src/plugin/transformations_pipeline.cpp | 6 ++++++ 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index 1f43ec08de16a6..f38859686b1e03 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -35,26 +35,24 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no data_layout.data_padding = padding(); - size_t key_cache_idx = cldnn::paged_attention::PagedAttentionInputIdx::KEY_CACHE; + const auto key_cache_idx = cldnn::paged_attention::PagedAttentionInputIdx::KEY_CACHE; const auto& key_cache_ps = impl_param.get_input_layout(key_cache_idx).get_partial_shape(); const auto& key_cache_quant_mode = impl_param.get_program().get_config().get_key_cache_quant_mode(); bool key_cache_compressed = impl_param.get_input_layout(key_cache_idx).data_type == ov::element::i8 || impl_param.get_input_layout(key_cache_idx).data_type == ov::element::u8; - size_t expected_block_size = paged_attention::block_size; - if (desc->has_xattention) { - expected_block_size = paged_attention::block_size_xattn; - key_cache_idx -= 1; - } + auto expected_block_size = desc->has_xattention ? paged_attention::block_size_xattn : paged_attention::block_size; if (key_cache_compressed && key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { expected_block_size += 4; } OPENVINO_ASSERT((key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) == desc->is_key_by_channel, "[GPU] Paged Attention key cache quantization mode mismatch: prim.is_key_by_channel : ", desc->is_key_by_channel, " but exec_config : ", impl_param.get_program().get_config().get_key_cache_quant_mode()); + + const auto block_size_idx = desc->has_xattention ? 2 : 3; bool valid_block_size = key_cache_ps.is_dynamic() || - (key_cache_ps[key_cache_idx].get_length() == static_cast(expected_block_size)); + (key_cache_ps[block_size_idx].get_length() == static_cast(expected_block_size)); OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation for key cache quant mode " - , key_cache_quant_mode, ". Expected ", expected_block_size, ", but got ", key_cache_ps[key_cache_idx].get_length()); + , key_cache_quant_mode, ". Expected ", expected_block_size, ", but got ", key_cache_ps[block_size_idx].get_length()); std::vector output_layouts{ data_layout }; if (desc->has_scores_output()) { diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 0c8fdf5432884b..3f3c628b02c2e1 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -37,7 +37,14 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared auto key_cache_ps = op->get_input_partial_shape(3); auto value_cache_ps = op->get_input_partial_shape(4); - auto k_head_size = has_rt_params ? rt_info.at(k_head_size_id).as() : key_cache_ps[2].get_length(); + // We may fallback to dense attention mode if xattn is not supported by either GPU archieture or compiler. + // So we check block_size from value cache shape, instead of checking op input type, to determine if xatnn is enabled. + if (value_cache_ps[2].get_length() == 256) { + prim.has_xattention = true; + } + const auto k_head_size_idx = prim.has_xattention ? 3 : 2; + + auto k_head_size = has_rt_params ? rt_info.at(k_head_size_id).as() : key_cache_ps[k_head_size_idx].get_length(); auto v_head_size = has_rt_params ? rt_info.at(v_head_size_id).as() : value_cache_ps[3].get_length(); auto kv_heads_num = has_rt_params ? rt_info.at(num_k_heads_id).as() : key_cache_ps[1].get_length(); @@ -91,12 +98,6 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared prim.has_rotated_blocks = true; } - // We may fallback to dense attention mode if xattn is not supported by either GPU archieture or compiler. - // So we check key cache shape, instead of checking op inputs to determine if xatnn is enabled. - if (key_cache_ps[3].get_length() == k_head_size && key_cache_ps[2].get_length() == 256) { - prim.has_xattention = true; - } - const size_t sinks_idx = cldnn::paged_attention::PagedAttentionInputIdx::SINKS; auto sinks_const = ov::as_type_ptr(op->get_input_node_shared_ptr(sinks_idx)); OPENVINO_ASSERT(sinks_const != nullptr); diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index f0788cd3abb075..157ab115b1b3df 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -551,6 +551,12 @@ void TransformationsPipeline::apply(std::shared_ptr func) { if (use_xattention) use_xattention = check_xattn_gpu_compatibility(); + // KVCache layout with default attention - + // k: [num_blocks, num_kv_heads, head_size, block_size(16)] + // v: [num_blocks, num_kv_heads, block_size(16), head_size] + // KVCache layout with XAttention - + // k: [num_blocks, num_kv_heads, block_size(256), head_size] + // v: [num_blocks, num_kv_heads, block_size(256), head_size] ov::pass::ConvertPagedAttnInputs::KVCacheConfig kv_cache_config; kv_cache_config.keyCachePrecision = config.get_kv_cache_precision(); kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); From 472f774ad10a9b3d5f01b4182acd589eeb887d91 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Mon, 13 Oct 2025 16:31:06 +0800 Subject: [PATCH 55/96] Remove CM PA tests Signed-off-by: Zhai, Xuejun --- .../unit/test_cases/paged_attention_gpu_test.cpp | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 7947a4e7305e5d..b5774ff8ba1beb 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1273,21 +1273,7 @@ const auto DYNAMIC_INPUT_PAD = true; const auto ENABLE_FA_V2 = false; const auto DISABLE_FA_V2 = true; - -#if ENABLE_PA_CM_PATH -INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ - /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, token_size>=32, disable_mix_mode */ - paged_attention_test_params{ {{32, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - - paged_attention_test_params{ {{1, 31}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 32}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 1023}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 127}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 129}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 32}}, 28, 128, 128, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token -})); -#else +#ifndef ENABLE_PA_CM_PATH INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ /* with scores output, use SnapKV */ paged_attention_test_params{ {{10, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token From 1fdcd3c505f17b04059e62f1ca8b546bb92354d7 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 13 Oct 2025 17:08:17 +0800 Subject: [PATCH 56/96] refactor: use paged_attention::block_size_xattn instead of hardcode number --- src/plugins/intel_gpu/src/graph/paged_attention.cpp | 3 --- src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp | 2 +- .../intel_gpu/src/plugin/transformations_pipeline.cpp | 8 ++++---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index f38859686b1e03..ca7a3907d1ebe4 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -11,9 +11,6 @@ namespace cldnn { GPU_DEFINE_PRIMITIVE_TYPE_ID(paged_attention) -constexpr size_t paged_attention::block_size; -constexpr size_t paged_attention::block_size_xattn; - layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*node*/, kernel_impl_params const& impl_param) { auto out_layout = impl_param.get_input_layout(0); diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 3f3c628b02c2e1..0e21f395beebd8 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -39,7 +39,7 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared // We may fallback to dense attention mode if xattn is not supported by either GPU archieture or compiler. // So we check block_size from value cache shape, instead of checking op input type, to determine if xatnn is enabled. - if (value_cache_ps[2].get_length() == 256) { + if (value_cache_ps[2].get_length() == cldnn::paged_attention::block_size_xattn) { prim.has_xattention = true; } const auto k_head_size_idx = prim.has_xattention ? 3 : 2; diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 157ab115b1b3df..cd7d5295b76431 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -562,19 +562,19 @@ void TransformationsPipeline::apply(std::shared_ptr func) { kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); kv_cache_config.inferencePrecision = infer_precision; if (use_xattention) { - kv_cache_config.keyCacheBlockSize = 256; + kv_cache_config.keyCacheBlockSize = cldnn::paged_attention::block_size_xattn; kv_cache_config.keyCacheDimOrder = {0, 1, 2, 3}; // default dim order of [num_blocks, num_kv_heads, block_size, head_size] } else { - kv_cache_config.keyCacheBlockSize = 16; + kv_cache_config.keyCacheBlockSize = cldnn::paged_attention::block_size; kv_cache_config.keyCacheDimOrder = {0, 1, 3, 2}; } kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; if (use_xattention) { - kv_cache_config.valueCacheBlockSize = 256; + kv_cache_config.valueCacheBlockSize = cldnn::paged_attention::block_size_xattn; kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; } else { - kv_cache_config.valueCacheBlockSize = 16; + kv_cache_config.valueCacheBlockSize = cldnn::paged_attention::block_size; kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; } kv_cache_config.valueCacheQuantBychannel = false; From a62fd1b44e096f477e455b81dce479ff914d6f8f Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 14 Oct 2025 00:04:53 +0800 Subject: [PATCH 57/96] worksgit status git status --- .../include/openvino/reference/x.bakup.cpp | 528 +++++ .../include/openvino/reference/xattention.hpp | 1985 +++++++++++++++-- .../src/graph/impls/cm/paged_attention.cpp | 60 +- .../unit/test_cases/xattention_gpu_test.cpp | 274 ++- .../test_utils/paged_attention_gpu_test.hpp | 44 + 5 files changed, 2609 insertions(+), 282 deletions(-) create mode 100644 src/core/reference/include/openvino/reference/x.bakup.cpp diff --git a/src/core/reference/include/openvino/reference/x.bakup.cpp b/src/core/reference/include/openvino/reference/x.bakup.cpp new file mode 100644 index 00000000000000..9e69cfabcf1816 --- /dev/null +++ b/src/core/reference/include/openvino/reference/x.bakup.cpp @@ -0,0 +1,528 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "openvino/reference/divide.hpp" +#include "openvino/reference/matmul.hpp" +#include "openvino/reference/softmax.hpp" +#include "openvino/reference/transpose.hpp" +#include "openvino/runtime/tensor.hpp" + +namespace ov::reference { + +using XAttentionBlockIndex = + std::pair; // .first is the *query* dimension block index, .second is *key* +using XAttentionRetainedBlockIndices = std::set; +using XAttentionRetainedBlockIndicesForAllHeads = std::vector; + +/** @brief Reference implementation of the XAttention sparse attention prefill mechanism + * (https://arxiv.org/abs/2503.16428) */ +template +class XAttentionBlockSelector { +public: + /** @param threshold Defines a threshold for introduced block sparsity - XAttention attempts to preserve the + * smallest subset of attention score matrix blocks so that the ratio of the attention score sum to the total sum of + * attention score matrix elements is no less than `threshold`. In other words, `threshold` defines a fraction of + * the attention score mass which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0 + * corresponding to 0% of the blocks retained, and 1.0 corresponding to 100% of the blocks retained. + * @param block_size The size of blocks into which the attention score matrix [num_heads, query_token_dimension, + * key_token_dimension] will be subdivided for purposes of determining the subset of the most important blocks + * according to `threshold`. This subdivision occurs on query and key dimensions of the attention score matrix with + * the same granularity, i.e. the resulting blocks have equal size on both dimensions. Essentially `block_size` + * defines the granularity of the eventual sparse attention computations. Must be a multiple of `stride`. + * @param stride The stride at which the full attention matrix is subsampled in a block-antidiagonal fashion to + * estimate the block importance. Note that the full attention matrix is not computed, instead the original query + * and key matrices are reshaped appropriately so that only the necessary elements are computed. Ideally, the + * computational complexity of the entire block estimation operation is `stride` times lower than the full attention + * matrix computation. + * */ + XAttentionBlockSelector(double threshold, size_t block_size, size_t stride) + : m_threshold(threshold), + m_block_size(block_size), + m_stride(stride) { + OPENVINO_ASSERT(m_block_size % m_stride == 0); + } + + /** Assuming the input tensor is either a query tensor or key tensor, reshapes it in a diagonal or antidiagonal + * fashion as appropriate so that the resulting matrices could be used to compute the block-antidiagonal subset of + * the attention matrix in further operations. For the query tensor, the antidiagonal reshaping should be applied, + * and diagonal - for the key tensor. Note that for the diagonal reshaping the data layout is effectively unchanged + * and only the shape can be adjusted in the efficient implementation of the same operation in HW. + * @param input_data Pointer to the input tensor data (query or key) + * @param input_shape Shape of the input tensor data (query or key). Expected shape is [num_heads, num_tokens, + * head_size], where `num_tokens` must be a multiple of `stride`. + * @param output_data Pointer to the output tensor data (reshaped query or key storage) + * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_tokens / stride, head_size * + * stride] + * @param is_antidiagonal Whether to reshape antidiagonally (true) or diagonally (false). Use `true` for query + * tensor and `false` for key tensor. + */ + void diagonal_reshape(const T* input_data, + const Shape& input_shape, + T* output_data, + const Shape& out_shape, + bool is_antidiagonal) { + OPENVINO_ASSERT(input_shape.size() == 3); // [num_heads, num_tokens, head_size] + OPENVINO_ASSERT(out_shape.size() == 3); + OPENVINO_ASSERT(input_shape[0] == out_shape[0]); + OPENVINO_ASSERT(input_shape[1] % m_stride == 0); + OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]); + OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]); + + size_t num_stride_steps = input_shape[1] / m_stride; + for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) { + size_t head_offset = head_idx * input_shape[1] * input_shape[2]; + for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) { + for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) { + size_t input_offset = head_offset; + size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2]; + if (is_antidiagonal) { + input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2]; + } else { + input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2]; + } + std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T)); + } + } + } + } + + /** Performs a matrix multiplication on the input tensors Q and K and scales the result in a typical attention op + * fashion, i.e. Q @ K^T / (sqrt(D) * S). Additionally rescales by the stride value, as compared to the regular + * attention. + * @param reshaped_query_data Pointer to the reshaped query input. + * @param reshaped_key_data Pointer to the reshaped key input. + * @param reshaped_query_shape Shape of the reshaped query input data. Expected shape is [num_heads, + * num_query_tokens / stride, head_size * stride]. + * @param reshaped_key_shape Shape of the reshaped key input data. Expected shape is [num_heads, num_key_tokens / + * stride, head_size * stride]. + * @param out Pointer to the output tensor data (attention logit scores) + * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / stride, + * num_key_tokens / stride] + */ + void transpose_matmul_scale(const T* reshaped_query_data, + const T* reshaped_key_data, + const Shape& reshaped_query_shape, + const Shape& reshaped_key_shape, + T* out, + const Shape& out_shape) { + OPENVINO_ASSERT(reshaped_key_shape.size() == 3); + OPENVINO_ASSERT(reshaped_query_shape.size() == 3); + OPENVINO_ASSERT(reshaped_query_shape[0] == reshaped_key_shape[0]); + OPENVINO_ASSERT(reshaped_query_shape[2] == reshaped_key_shape[2]); + + OPENVINO_ASSERT(out_shape.size() == 3); + OPENVINO_ASSERT(out_shape[0] == reshaped_query_shape[0]); + OPENVINO_ASSERT(out_shape[1] == reshaped_query_shape[1]); + OPENVINO_ASSERT(out_shape[2] == reshaped_key_shape[1]); + + ov::reference::matmul(reshaped_query_data, + reshaped_key_data, + out, + reshaped_query_shape, + reshaped_key_shape, + out_shape, + /* transpose_arg0 = */ false, + /* transpose_arg1 = */ true); + + size_t out_size = out_shape[0] * out_shape[1] * out_shape[2]; + + for (size_t i = 0; i < out_size; i++) { + // The D in the formula above refers to the original head dimension, while + // reshaped_query_shape[2] had been scaled in the process of reshaping, therefore + // the formula is also adjusted: + out[i] = out[i] / std::sqrt(reshaped_query_shape[2] * m_stride); + } + } + + /** Performs a softmax operation on the last dimension of the rank-3 input tensor. + * @param reshaped_qk_product_data Pointer to the reshaped query-key product input (attention logits pre-softmax). + * @param reshaped_qk_product_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens / + * stride, num_key_tokens / stride]. + * @param out Pointer to the output tensor data (attention scores) + * @param out_shape Shape of the output tensor data. Expected shape is strictly equal to + * `reshaped_qk_product_shape`. + */ + void softmax(const T* reshaped_qk_product_data, + const Shape& reshaped_qk_product_shape, + T* out, + const Shape& out_shape) { + OPENVINO_ASSERT(reshaped_qk_product_shape.size() == 3); + OPENVINO_ASSERT(reshaped_qk_product_shape == out_shape); + ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2}); + } + + /** Divides the input rank-3 tensor into blocks along last two dimensions, performs the addition of the values + * inside each block and outputs each block sum into corresponding positions in the output tensor downsampled along + * the same dimensions. The output tensor dimensions are such that the query and key token dimensions are + * downsampled by `block_size` when compared to the *original* query and key tensors. + * @param attention_scores_data Pointer to the attention score input. + * @param attention_score_shape Shape of the attention score input tensor. Expected shape is [num_heads, + * num_query_tokens / stride, num_key_tokens / stride], where `num_query_tokens` and `num_key_tokens` must be + * multiples of `block_size`. + * @param out Pointer to the output tensor data (block sums) + * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / block_size, + * num_key_tokens / block_size]. + */ + void block_sum_attention_scores(const T* attention_scores_data, + const Shape& attention_scores_shape, + T* out, + const Shape& out_shape) { + OPENVINO_ASSERT(attention_scores_shape.size() == 3); // [num_heads, query_antidiagonals, key_antidiagonals] + size_t antidiagonals_per_xattention_block = m_block_size / m_stride; + OPENVINO_ASSERT(attention_scores_shape[1] % antidiagonals_per_xattention_block == 0); + OPENVINO_ASSERT(attention_scores_shape[2] % antidiagonals_per_xattention_block == 0); + + OPENVINO_ASSERT(out_shape[0] == attention_scores_shape[0]); + OPENVINO_ASSERT(out_shape[1] == + attention_scores_shape[1] / antidiagonals_per_xattention_block); // query length, blocked + OPENVINO_ASSERT(out_shape[2] == + attention_scores_shape[2] / antidiagonals_per_xattention_block); // key length, blocked + + std::memset(out, 0, out_shape[0] * out_shape[1] * out_shape[2] * sizeof(T)); + + for (size_t head_idx = 0; head_idx < attention_scores_shape[0]; head_idx++) { + size_t in_head_offset = head_idx * attention_scores_shape[1] * attention_scores_shape[2]; + size_t out_head_offset = head_idx * out_shape[1] * out_shape[2]; + for (size_t query_len_idx = 0; query_len_idx < attention_scores_shape[1]; query_len_idx++) { + for (size_t key_len_idx = 0; key_len_idx < attention_scores_shape[2]; key_len_idx++) { + size_t query_block_idx = query_len_idx / antidiagonals_per_xattention_block; + size_t key_block_idx = key_len_idx / antidiagonals_per_xattention_block; + auto target_block_sum_ptr = out + out_head_offset + query_block_idx * out_shape[2] + key_block_idx; + *target_block_sum_ptr += *(attention_scores_data + in_head_offset + + query_len_idx * attention_scores_shape[2] + key_len_idx); + } + } + } + } + + /** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension, + * so that the elements constitute a smallest subset constituting a sum portion no less than `threshold` of the + * total element sum. + * @param blocked_scores_data Pointer to the blocked score input. + * @param blocked_attention_scores_shape Shape of the blocked score input tensor. Expected shape is [num_heads, + * num_query_tokens / block_size, num_key_tokens / block_size] + * @return A vector of size `num_heads` of sets, each set containing pairs of block indices (.first is the block + * index along the query dimension, .second - along the key). Each set is the head-specific subset of blocks + * corresponding to the property described above. + */ +// template +void print_blocked_attention_scores(const T* data, + size_t num_heads, + size_t num_q_blocks, + size_t num_k_blocks) { + std::cout << "blocked_attention_scores shape: [" + << num_heads << ", " << num_q_blocks << ", " << num_k_blocks << "]\n"; + + for (size_t h = 0; h < num_heads; ++h) { + std::cout << "Head " << h << ":\n"; + std::cout << std::setw(8) << ""; + for (size_t k = 0; k < num_k_blocks; ++k) { + std::cout << std::setw(12) << ("K" + std::to_string(k)); + } + std::cout << "\n"; + + for (size_t q = 0; q < num_q_blocks; ++q) { + std::cout << std::setw(6) << ("Q" + std::to_string(q)) << " "; + double row_sum = 0.0; + for (size_t k = 0; k < num_k_blocks; ++k) { + size_t idx = h * (num_q_blocks * num_k_blocks) + q * num_k_blocks + k; + double v = static_cast(static_cast(*(data + idx))); + row_sum += v; + std::cout << std::setw(12) << std::fixed << std::setprecision(6) << v; + } + std::cout << " sum=" << std::fixed << std::setprecision(6) << row_sum << "\n"; + } + std::cout << std::flush; + } +} +// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( +// const T* blocked_attention_scores_data, +// const Shape& blocked_attention_scores_shape) { +// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); +// // [num_heads, num_blocks_in_query, num_blocks_in_key] + +// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + +// struct IndexAndScore { +// XAttentionBlockIndex idx; +// T score; +// }; + +// const size_t num_heads = blocked_attention_scores_shape[0]; +// const size_t num_q_blocks = blocked_attention_scores_shape[1]; +// const size_t num_k_blocks = blocked_attention_scores_shape[2]; +// // print_blocked_attention_scores(blocked_attention_scores_data, num_heads, num_q_blocks, num_k_blocks); + +// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { +// size_t head_offset = head_idx * num_q_blocks * num_k_blocks; + +// for (size_t q_block_idx = 0; q_block_idx < num_q_blocks; q_block_idx++) { +// std::vector indices_and_scores; +// indices_and_scores.reserve(num_k_blocks); + +// double total_sum = 0.0; + +// for (size_t k_block_idx = 0; k_block_idx < num_k_blocks; k_block_idx++) { +// size_t target_offset = head_offset + q_block_idx * num_k_blocks + k_block_idx; +// T current_score = *(blocked_attention_scores_data + target_offset); +// indices_and_scores.push_back({{q_block_idx, k_block_idx}, current_score}); +// total_sum += current_score; +// } + +// double required_sum = m_threshold * total_sum; + +// std::sort(indices_and_scores.begin(), indices_and_scores.end(), +// [](const IndexAndScore& a, const IndexAndScore& b) { +// return a.score > b.score; +// }); + +// std::vector shifted_cumsum(num_k_blocks, 0.0); +// for (size_t i = 1; i < num_k_blocks; i++) { +// shifted_cumsum[i] = shifted_cumsum[i - 1] + indices_and_scores[i - 1].score; +// } + +// for (size_t i = 0; i < num_k_blocks; i++) { +// if (shifted_cumsum[i] < required_sum) { +// retval[head_idx].insert(indices_and_scores[i].idx); +// } +// } +// } +// } + +// return retval; +// } + + + +void dump_blocked_attention_scores_bin(const std::string& filename, + const float* data, + size_t num_heads, + size_t num_q_blocks, + size_t num_k_blocks) { + size_t total_elems = num_heads * num_q_blocks * num_k_blocks; + std::ofstream ofs(filename, std::ios::binary); + if (!ofs) { + std::cerr << "Failed to open file for writing: " << filename << std::endl; + return; + } + ofs.write(reinterpret_cast(data), total_elems * sizeof(float)); + ofs.close(); + + std::cout << "✅ Dumped blocked_attention_scores to: " << filename + << " (" << total_elems << " elements, " + << sizeof(float) * total_elems / 1024.0 << " KB)\n"; +} + +// template +XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( + const T* blocked_attention_scores_data, + const Shape& blocked_attention_scores_shape) { + OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); + // [num_heads, num_blocks_in_query, num_blocks_in_key] + + auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + + struct IndexAndScore { + XAttentionBlockIndex idx; + T score; + }; + + const size_t num_heads = blocked_attention_scores_shape[0]; + const size_t num_q_blocks = blocked_attention_scores_shape[1]; + const size_t num_k_blocks = blocked_attention_scores_shape[2]; + print_blocked_attention_scores(blocked_attention_scores_data, num_heads, num_q_blocks, num_k_blocks); + + size_t total_elems = num_heads * num_q_blocks * num_k_blocks; + std::vector data_f32(total_elems); + for (size_t i = 0; i < total_elems; i++) + data_f32[i] = static_cast(blocked_attention_scores_data[i]); + dump_blocked_attention_scores_bin("blocked_attention_scores.bin", + data_f32.data(), num_heads, num_q_blocks, num_k_blocks); + for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { + size_t head_offset = head_idx * num_q_blocks * num_k_blocks; + + for (size_t q_block_idx = 0; q_block_idx < num_q_blocks; q_block_idx++) { + std::vector indices_and_scores; + indices_and_scores.reserve(num_k_blocks); + + double total_sum = 0.0; + for (size_t k_block_idx = 0; k_block_idx < num_k_blocks; k_block_idx++) { + size_t offset = head_offset + q_block_idx * num_k_blocks + k_block_idx; + T score = *(blocked_attention_scores_data + offset); + indices_and_scores.push_back({{q_block_idx, k_block_idx}, score}); + total_sum += score; + } + + double required_sum = m_threshold * total_sum; + + // === 与 Python 一致:按 score 降序排序 === + std::sort(indices_and_scores.begin(), indices_and_scores.end(), + [](const IndexAndScore& a, const IndexAndScore& b) { + return a.score > b.score; + }); + + // === 模拟 Python 的 cumulative_sum_without_self === + // 即:每个元素的累积和是“之前所有元素的和”,自身不计入。 + std::vector shifted_cumsum(num_k_blocks, 0.0); + for (size_t i = 1; i < num_k_blocks; i++) { + shifted_cumsum[i] = shifted_cumsum[i - 1] + indices_and_scores[i - 1].score; + } + + // === 选择 cumulative_sum_without_self < required_sum 的 block === + for (size_t i = 0; i < num_k_blocks; i++) { + if (shifted_cumsum[i] < required_sum) { + retval[head_idx].insert(indices_and_scores[i].idx); + } + } + + // ✅ Python 中通常也会强制保留“自身 block”,即 (q_block_idx, q_block_idx) + retval[head_idx].insert({q_block_idx, q_block_idx}); + } + } + + return retval; +} + + + // XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, + // const Shape& blocked_attention_scores_shape) { + // OPENVINO_ASSERT(blocked_attention_scores_shape.size() == + // 3); // [num_heads, num_blocks_in_query, num_blocks_in_key] + + // auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + + // struct IndexAndScore { + // XAttentionBlockIndex idx; + // T score; + // bool operator<(const IndexAndScore& rhs) const { + // return score < rhs.score; + // } + // }; + + // for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) { + // size_t head_offset = head_idx * blocked_attention_scores_shape[1] * blocked_attention_scores_shape[2]; + // std::priority_queue indices_and_scores_queue; + // double total_sum = 0.0; + // for (size_t q_block_idx = 0; q_block_idx < blocked_attention_scores_shape[1]; q_block_idx++) { + + // for (size_t k_block_idx = 0; k_block_idx < blocked_attention_scores_shape[2]; k_block_idx++) { + // size_t target_offset = head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx; + // T current_score = *(blocked_attention_scores_data + target_offset); + // indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score}); + // total_sum += current_score; + // } + // } + // double cumsum = 0.0; + // double required_sum = m_threshold * total_sum; + // while (cumsum < required_sum && !indices_and_scores_queue.empty()) { + // auto index_and_largest_score = indices_and_scores_queue.top(); + // indices_and_scores_queue.pop(); + // cumsum += index_and_largest_score.score; + // retval[head_idx].insert(index_and_largest_score.idx); + // } + // } + // return retval; + // } + + /** Applies XAttention to the provided query and key matrices, returning the subset of the most important blocks for + * each attention head, according to the configured block size and threshold, which are to be preserved in the + * subsequent sparse attention computation. + * @param query_data Pointer to the query input tensor data + * @param query_shape Shape of the query input tensor data. Expected shape is [num_heads, num_query_tokens, + * head_size], where `num_query_tokens` must be a multiple of both `block_size` and `stride`, padded with zeroes if + * necessary to do so in the real-world scenario. + * @param key_data Pointer to the key input tensor data + * @param key_shape Shape of the key input tensor data. Expected shape is [num_heads, num_key_tokens, head_size], + * where `num_key_tokens` must be a multiple of both `block_size` and `stride`, padded with zeroes if necessary to + * do so in the real-world scenario. + * @return A vector of size `num_heads` of sets, each set containing pairs of block indices (.first is the block + * index along the query dimension, .second - along the key). Each set is the head-specific subset of blocks that + * must be preserved in the sparse attention computation. Indices are given in units of XAttention-specific + * `block_size` (as configured), which may differ from the block size in the paged attention implementation. + */ + XAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, + const Shape& query_shape, + const T* key_data, + const Shape& key_shape) { + OPENVINO_ASSERT(query_shape.size() == 3); // [num_heads, query_token_len, head_dim] + OPENVINO_ASSERT(key_shape.size() == 3); // [num_heads, key_token_len, head_dim] + + OPENVINO_ASSERT(key_shape[0] == query_shape[0]); + OPENVINO_ASSERT(key_shape[2] == query_shape[2]); + + OPENVINO_ASSERT(query_shape[1] % m_stride == 0); + OPENVINO_ASSERT(key_shape[1] % m_stride == 0); + + OPENVINO_ASSERT(query_shape[1] % m_block_size == 0); + OPENVINO_ASSERT(key_shape[1] % m_block_size == 0); + + Shape reshaped_query_shape = {query_shape[0], query_shape[1] / m_stride, query_shape[2] * m_stride}; + auto q_buf = allocate_buf(reshaped_query_shape); + diagonal_reshape(query_data, query_shape, q_buf.get(), reshaped_query_shape, /* is_antidiagonal = */ true); + + Shape reshaped_key_shape = {key_shape[0], key_shape[1] / m_stride, key_shape[2] * m_stride}; + auto k_buf = allocate_buf(reshaped_key_shape); + diagonal_reshape(key_data, key_shape, k_buf.get(), reshaped_key_shape, /* is_antidiagonal = */ false); + + Shape transpose_matmul_scaled_shape = {key_shape[0], query_shape[1] / m_stride, key_shape[1] / m_stride}; + auto qk_buf = allocate_buf(transpose_matmul_scaled_shape); + transpose_matmul_scale(q_buf.get(), + k_buf.get(), + reshaped_query_shape, + reshaped_key_shape, + qk_buf.get(), + transpose_matmul_scaled_shape); + q_buf.reset(); + k_buf.reset(); + + Shape attention_scores_shape = transpose_matmul_scaled_shape; + auto attn_score_buf = allocate_buf(attention_scores_shape); + softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape); + qk_buf.reset(); + + size_t antidiagonals_per_xattention_block = m_block_size / m_stride; + Shape block_sum_shape = {attention_scores_shape[0], + attention_scores_shape[1] / antidiagonals_per_xattention_block, + attention_scores_shape[2] / antidiagonals_per_xattention_block}; + auto block_sum_buf = allocate_buf(block_sum_shape); + block_sum_attention_scores(attn_score_buf.get(), attention_scores_shape, block_sum_buf.get(), block_sum_shape); + attn_score_buf.reset(); + + auto selected_block_indices = get_block_indices_to_keep(block_sum_buf.get(), block_sum_shape); + block_sum_buf.reset(); + + return selected_block_indices; + } + + /** + * @param shape Shape of a tensor + * @return A shared_ptr owning a buffer that can be used to store tensor data for the given shape. + * */ + std::shared_ptr allocate_buf(const Shape& shape) { + return std::shared_ptr(new T[ov::shape_size(shape)]); + } + + /** + * @param token_length An integer value + * @return The closest multiple of `block_size` to `token_length`, rounding up. + * */ + size_t pad_to_block(size_t token_length) { + return (token_length + m_block_size - 1) / m_block_size * m_block_size; + } + + double m_threshold; + size_t m_block_size; + size_t m_stride; +}; + +} // namespace ov::reference \ No newline at end of file diff --git a/src/core/reference/include/openvino/reference/xattention.hpp b/src/core/reference/include/openvino/reference/xattention.hpp index 49e01042417caf..0bb181e2460e23 100644 --- a/src/core/reference/include/openvino/reference/xattention.hpp +++ b/src/core/reference/include/openvino/reference/xattention.hpp @@ -4,8 +4,10 @@ #pragma once +#include #include #include +#include #include #include @@ -17,32 +19,18 @@ namespace ov::reference { +using Shape = std::vector; + using XAttentionBlockIndex = std::pair; // .first is the *query* dimension block index, .second is *key* using XAttentionRetainedBlockIndices = std::set; using XAttentionRetainedBlockIndicesForAllHeads = std::vector; /** @brief Reference implementation of the XAttention sparse attention prefill mechanism - * (https://arxiv.org/abs/2503.16428) */ + *[](https://arxiv.org/abs/2503.16428) */ template class XAttentionBlockSelector { public: - /** @param threshold Defines a threshold for introduced block sparsity - XAttention attempts to preserve the - * smallest subset of attention score matrix blocks so that the ratio of the attention score sum to the total sum of - * attention score matrix elements is no less than `threshold`. In other words, `threshold` defines a fraction of - * the attention score mass which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0 - * corresponding to 0% of the blocks retained, and 1.0 corresponding to 100% of the blocks retained. - * @param block_size The size of blocks into which the attention score matrix [num_heads, query_token_dimension, - * key_token_dimension] will be subdivided for purposes of determining the subset of the most important blocks - * according to `threshold`. This subdivision occurs on query and key dimensions of the attention score matrix with - * the same granularity, i.e. the resulting blocks have equal size on both dimensions. Essentially `block_size` - * defines the granularity of the eventual sparse attention computations. Must be a multiple of `stride`. - * @param stride The stride at which the full attention matrix is subsampled in a block-antidiagonal fashion to - * estimate the block importance. Note that the full attention matrix is not computed, instead the original query - * and key matrices are reshaped appropriately so that only the necessary elements are computed. Ideally, the - * computational complexity of the entire block estimation operation is `stride` times lower than the full attention - * matrix computation. - * */ XAttentionBlockSelector(double threshold, size_t block_size, size_t stride) : m_threshold(threshold), m_block_size(block_size), @@ -50,26 +38,12 @@ class XAttentionBlockSelector { OPENVINO_ASSERT(m_block_size % m_stride == 0); } - /** Assuming the input tensor is either a query tensor or key tensor, reshapes it in a diagonal or antidiagonal - * fashion as appropriate so that the resulting matrices could be used to compute the block-antidiagonal subset of - * the attention matrix in further operations. For the query tensor, the antidiagonal reshaping should be applied, - * and diagonal - for the key tensor. Note that for the diagonal reshaping the data layout is effectively unchanged - * and only the shape can be adjusted in the efficient implementation of the same operation in HW. - * @param input_data Pointer to the input tensor data (query or key) - * @param input_shape Shape of the input tensor data (query or key). Expected shape is [num_heads, num_tokens, - * head_size], where `num_tokens` must be a multiple of `stride`. - * @param output_data Pointer to the output tensor data (reshaped query or key storage) - * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_tokens / stride, head_size * - * stride] - * @param is_antidiagonal Whether to reshape antidiagonally (true) or diagonally (false). Use `true` for query - * tensor and `false` for key tensor. - */ void diagonal_reshape(const T* input_data, const Shape& input_shape, T* output_data, const Shape& out_shape, bool is_antidiagonal) { - OPENVINO_ASSERT(input_shape.size() == 3); // [num_heads, num_tokens, head_size] + OPENVINO_ASSERT(input_shape.size() == 3); OPENVINO_ASSERT(out_shape.size() == 3); OPENVINO_ASSERT(input_shape[0] == out_shape[0]); OPENVINO_ASSERT(input_shape[1] % m_stride == 0); @@ -94,19 +68,56 @@ class XAttentionBlockSelector { } } - /** Performs a matrix multiplication on the input tensors Q and K and scales the result in a typical attention op - * fashion, i.e. Q @ K^T / (sqrt(D) * S). Additionally rescales by the stride value, as compared to the regular - * attention. - * @param reshaped_query_data Pointer to the reshaped query input. - * @param reshaped_key_data Pointer to the reshaped key input. - * @param reshaped_query_shape Shape of the reshaped query input data. Expected shape is [num_heads, - * num_query_tokens / stride, head_size * stride]. - * @param reshaped_key_shape Shape of the reshaped key input data. Expected shape is [num_heads, num_key_tokens / - * stride, head_size * stride]. - * @param out Pointer to the output tensor data (attention logit scores) - * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / stride, - * num_key_tokens / stride] - */ +void diagonal_reshape_kdb1_no_batch( + const T* input_data, // 原始 query buffer + const std::vector& input_shape, // [H, Q_orig, dim] + T* output_data, // 输出 q_buf + const std::vector& output_shape) +{ + + size_t H = input_shape[0]; + size_t Q_orig = input_shape[1]; + size_t dim = input_shape[2]; + size_t Q_new = output_shape[1]; + + + for (size_t h = 0; h < H; ++h) { + size_t head_in_offset = h * Q_orig * dim; + size_t head_out_offset = h * Q_new * m_stride * dim; + + for (size_t s = 0; s < m_stride; ++s) { + for (size_t q = 0; q < Q_new; ++q) { + size_t in_idx = head_in_offset + (m_stride - 1 - s + q * m_stride) * dim; + size_t out_idx = head_out_offset + q * m_stride * dim + s * dim; + std::memcpy(output_data + out_idx, input_data + in_idx, dim * sizeof(T)); + } + } + } +} + void diagonal_reshape_q(const T* input_data, + const Shape& input_shape, + T* output_data, + const Shape& out_shape, + bool is_antidiagonal) { + size_t B = 1; + size_t H = input_shape[0]; + int Q = input_shape[1]; + int dim = input_shape[2]; + for (size_t b = 0; b < B; ++b) { + for (size_t h = 0; h < H; ++h) { + size_t head_offset_in = b * H * Q * dim + h * Q * dim; + size_t head_offset_out = b * H * Q * dim * m_stride + h * Q * dim * m_stride; + for (size_t q = 0; q < Q / m_stride; ++q) { + for (size_t s = 0; s < m_stride; ++s) { + size_t in_idx = head_offset_in + (Q / m_stride) * s + q; // 交错取值 + size_t out_idx = head_offset_out + q * m_stride * dim + s * dim; // 拼接到最后维度 + std::memcpy(output_data + out_idx, input_data + in_idx * dim, dim * sizeof(T)); + } + } + } + } + } + void transpose_matmul_scale(const T* reshaped_query_data, const T* reshaped_key_data, const Shape& reshaped_query_shape, @@ -129,27 +140,82 @@ class XAttentionBlockSelector { reshaped_query_shape, reshaped_key_shape, out_shape, - /* transpose_arg0 = */ false, - /* transpose_arg1 = */ true); + false, + true); size_t out_size = out_shape[0] * out_shape[1] * out_shape[2]; for (size_t i = 0; i < out_size; i++) { - // The D in the formula above refers to the original head dimension, while - // reshaped_query_shape[2] had been scaled in the process of reshaping, therefore - // the formula is also adjusted: out[i] = out[i] / std::sqrt(reshaped_query_shape[2] * m_stride); } } - /** Performs a softmax operation on the last dimension of the rank-3 input tensor. - * @param reshaped_qk_product_data Pointer to the reshaped query-key product input (attention logits pre-softmax). - * @param reshaped_qk_product_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens / - * stride, num_key_tokens / stride]. - * @param out Pointer to the output tensor data (attention scores) - * @param out_shape Shape of the output tensor data. Expected shape is strictly equal to - * `reshaped_qk_product_shape`. - */ +void softmax_ww(const T* reshaped_qk_product_data, + const Shape& reshaped_qk_product_shape, + T* out, + const Shape& out_shape) { + OPENVINO_ASSERT(reshaped_qk_product_shape.size() == 3); + OPENVINO_ASSERT(reshaped_qk_product_shape == out_shape); + + size_t num_heads = reshaped_qk_product_shape[0]; + size_t q_blocks = reshaped_qk_product_shape[1]; + size_t k_blocks = reshaped_qk_product_shape[2]; + + std::vector temp_in(q_blocks * k_blocks); + std::vector temp_out(q_blocks * k_blocks); + + for (size_t h = 0; h < num_heads; ++h) { + for (size_t q = 0; q < q_blocks; ++q) { + // 将输入从 half 转为 float + for (size_t k = 0; k < k_blocks; ++k) { + size_t idx = h * q_blocks * k_blocks + q * k_blocks + k; + temp_in[k] = static_cast(reshaped_qk_product_data[idx]); + } + + // 数值稳定 softmax: 先减去最大值 + float max_val = *std::max_element(temp_in.begin(), temp_in.end()); + float sum_exp = 0.f; + for (size_t k = 0; k < k_blocks; ++k) { + temp_out[k] = std::exp(temp_in[k] - max_val); + sum_exp += temp_out[k]; + } + + // 归一化 + float inv_sum = 1.f / (sum_exp + 1e-12f); + for (size_t k = 0; k < k_blocks; ++k) { + size_t idx = h * q_blocks * k_blocks + q * k_blocks + k; + out[idx] = static_cast(temp_out[k] * inv_sum); + } + } + } +} + +void softmax_fp32(const T* input, const Shape& shape, T* output, const Shape& out_shape) { + OPENVINO_ASSERT(shape.size() == 3); + size_t dim0 = shape[0], dim1 = shape[1], dim2 = shape[2]; + + std::vector temp(dim2); + for (size_t i = 0; i < dim0 * dim1; ++i) { + size_t offset = i * dim2; + + // 1. 转为 float32 + for (size_t j = 0; j < dim2; ++j) + temp[j] = static_cast(input[offset + j]); + + // 2. 稳定 softmax + float max_val = *std::max_element(temp.begin(), temp.end()); + float sum_exp = 0.f; + for (float& v : temp) { + v = std::exp(v - max_val); + sum_exp += v; + } + + // 3. 写回 + for (size_t j = 0; j < dim2; ++j) + output[offset + j] = static_cast(temp[j] / sum_exp); + } +} + void softmax(const T* reshaped_qk_product_data, const Shape& reshaped_qk_product_shape, T* out, @@ -159,32 +225,18 @@ class XAttentionBlockSelector { ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2}); } - /** Divides the input rank-3 tensor into blocks along last two dimensions, performs the addition of the values - * inside each block and outputs each block sum into corresponding positions in the output tensor downsampled along - * the same dimensions. The output tensor dimensions are such that the query and key token dimensions are - * downsampled by `block_size` when compared to the *original* query and key tensors. - * @param attention_scores_data Pointer to the attention score input. - * @param attention_score_shape Shape of the attention score input tensor. Expected shape is [num_heads, - * num_query_tokens / stride, num_key_tokens / stride], where `num_query_tokens` and `num_key_tokens` must be - * multiples of `block_size`. - * @param out Pointer to the output tensor data (block sums) - * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / block_size, - * num_key_tokens / block_size]. - */ void block_sum_attention_scores(const T* attention_scores_data, const Shape& attention_scores_shape, T* out, const Shape& out_shape) { - OPENVINO_ASSERT(attention_scores_shape.size() == 3); // [num_heads, query_antidiagonals, key_antidiagonals] + OPENVINO_ASSERT(attention_scores_shape.size() == 3); size_t antidiagonals_per_xattention_block = m_block_size / m_stride; OPENVINO_ASSERT(attention_scores_shape[1] % antidiagonals_per_xattention_block == 0); OPENVINO_ASSERT(attention_scores_shape[2] % antidiagonals_per_xattention_block == 0); OPENVINO_ASSERT(out_shape[0] == attention_scores_shape[0]); - OPENVINO_ASSERT(out_shape[1] == - attention_scores_shape[1] / antidiagonals_per_xattention_block); // query length, blocked - OPENVINO_ASSERT(out_shape[2] == - attention_scores_shape[2] / antidiagonals_per_xattention_block); // key length, blocked + OPENVINO_ASSERT(out_shape[1] == attention_scores_shape[1] / antidiagonals_per_xattention_block); + OPENVINO_ASSERT(out_shape[2] == attention_scores_shape[2] / antidiagonals_per_xattention_block); std::memset(out, 0, out_shape[0] * out_shape[1] * out_shape[2] * sizeof(T)); @@ -203,230 +255,1743 @@ class XAttentionBlockSelector { } } - /** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension, - * so that the elements constitute a smallest subset constituting a sum portion no less than `threshold` of the - * total element sum. - * @param blocked_scores_data Pointer to the blocked score input. - * @param blocked_attention_scores_shape Shape of the blocked score input tensor. Expected shape is [num_heads, - * num_query_tokens / block_size, num_key_tokens / block_size] - * @return A vector of size `num_heads` of sets, each set containing pairs of block indices (.first is the block - * index along the query dimension, .second - along the key). Each set is the head-specific subset of blocks - * corresponding to the property described above. - */ -// template -// void print_blocked_attention_scores(const T* data, -// size_t num_heads, -// size_t num_q_blocks, -// size_t num_k_blocks) { -// std::cout << "blocked_attention_scores shape: [" -// << num_heads << ", " << num_q_blocks << ", " << num_k_blocks << "]\n"; - -// for (size_t h = 0; h < num_heads; ++h) { -// std::cout << "Head " << h << ":\n"; -// std::cout << std::setw(8) << ""; -// for (size_t k = 0; k < num_k_blocks; ++k) { -// std::cout << std::setw(12) << ("K" + std::to_string(k)); +// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( +// const std::vector& input_tensor, // flattened [batch, head, q_block_num, k_block_num] +// size_t batch_size, +// size_t num_heads, +// size_t q_block_num, +// size_t k_block_num, +// double threshold, +// size_t block_size, +// size_t stride, +// bool causal = true) { + +// XAttentionRetainedBlockIndicesForAllHeads retained_blocks(num_heads); + +// for (size_t b = 0; b < batch_size; ++b) { +// for (size_t h = 0; h < num_heads; ++h) { +// auto& retained = retained_blocks[h]; +// const size_t base_offset = ((b * num_heads + h) * q_block_num) * k_block_num; + +// for (size_t q_block_idx = 0; q_block_idx < q_block_num; ++q_block_idx) { +// size_t diagonal_k = q_block_idx; +// std::vector> others; + +// // 1. 收集当前 query block 对所有 key block 的分数 +// double row_sum = 0.0; +// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { +// double score = input_tensor[base_offset + q_block_idx * k_block_num + k_block_idx]; +// if (std::isnan(score) || std::isinf(score)) +// score = 0.0; +// row_sum += score; +// if (k_block_idx != 0 && k_block_idx != diagonal_k) { +// others.emplace_back(score, k_block_idx); +// } +// } + +// // Debug: 打印 row_sum 和 q_block_idx +// /* +// if (h == 0) +// std::cout << "[Debug] q=" << q_block_idx +// << " row_sum=" << row_sum << " others=" << others.size() << "\n"; +// */ + +// if (row_sum <= 0.0) +// continue; + +// // 2. 强制保留 (q, 0) 和 diagonal +// retained.insert({q_block_idx, 0}); +// retained.insert({q_block_idx, diagonal_k}); + +// // 3. 按分数降序排列 others +// std::sort(others.begin(), others.end(), +// [](const auto& a, const auto& b) { return a.first > b.first; }); + +// // 4. 计算累计阈值 +// double required_sum = threshold * row_sum; +// double cumsum = 0.0; + +// std::priority_queue pq; + +// // ✅ 修复点:原代码用了 others.size() - 2,导致丢项。应当 push 全部候选。 +// for (size_t i = 0; i < others.size(); ++i) { +// pq.push({others[i].second, others[i].first}); +// } + +// // Debug: 打印 top 若干项 +// /* +// if (h == 0 && (q_block_idx == 6 || q_block_idx == 7)) { +// std::cout << "[Debug] q=" << q_block_idx << " others(sorted): "; +// for (size_t i = 0; i < std::min(others.size(), 8); ++i) +// std::cout << "(" << others[i].second << "," << std::fixed << std::setprecision(3) +// << others[i].first << ") "; +// std::cout << "\n"; +// } +// */ + +// // 5. 从大到小取,直到累计到阈值 +// while (!pq.empty() && cumsum < required_sum) { +// auto top = pq.top(); +// pq.pop(); +// cumsum += top.score; +// retained.insert({q_block_idx, top.index}); +// } + +// // Debug: 打印累计结果 +// /* +// if (h == 0 && (q_block_idx == 6 || q_block_idx == 7)) { +// std::cout << "[Debug] q=" << q_block_idx +// << " required=" << required_sum +// << " cumsum=" << cumsum +// << " retained=" << retained.size() << "\n"; +// } +// */ + +// // 6. causal mask:只保留 k <= q +// if (causal) { +// std::set> causal_retained; +// for (auto& kv : retained) { +// if (kv.second <= kv.first) +// causal_retained.insert(kv); +// } +// retained = std::move(causal_retained); +// } +// } // } -// std::cout << "\n"; +// } + +// return retained_blocks; +// } + + +// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, +// const Shape& blocked_attention_scores_shape) { +// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); + +// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + +// struct IndexAndScore { +// size_t k_block_idx; +// double score; +// bool operator<(const IndexAndScore& rhs) const { +// return score < rhs.score; +// } +// }; + +// size_t q_block_num = blocked_attention_scores_shape[1]; +// size_t k_block_num = blocked_attention_scores_shape[2]; +// size_t current_index = k_block_num - q_block_num; + +// for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) { +// auto& retained = retval[head_idx]; +// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { +// double row_sum = 0.0; +// for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { +// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// row_sum += static_cast(blocked_attention_scores_data[offset]); +// } + +// double required_sum = m_threshold * row_sum; +// double cumsum = 0.0; +// // Force include first +// size_t k_block_idx = 0; +// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// double score = static_cast(blocked_attention_scores_data[offset]); +// cumsum += score; +// retained.insert({q_block_idx, k_block_idx}); +// // Force include diagonal +// size_t diagonal_k = current_index + q_block_idx; +// offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + diagonal_k; +// score = static_cast(blocked_attention_scores_data[offset]); +// cumsum += score; +// retained.insert({q_block_idx, diagonal_k}); +// // Others + +// std::vector> others; +// for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { +// if (k_block_idx == 0 || k_block_idx == diagonal_k) +// continue; +// offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// double sc = static_cast(blocked_attention_scores_data[offset]); +// others.emplace_back(sc, k_block_idx); +// } + +// std::sort(others.begin(), others.end(), [](const auto& a, const auto& b) { +// return a.first > b.first; +// }); + +// std::priority_queue indices_and_scores_queue; + +// for (size_t i = 0; i < others.size() - 2; i++) { +// if (i >= others.size()) +// break; + +// indices_and_scores_queue.push({others[i].second, others[i].first}); +// } + +// while (cumsum < required_sum && !indices_and_scores_queue.empty()) { +// auto index_and_largest_score = indices_and_scores_queue.top(); + +// indices_and_scores_queue.pop(); + +// cumsum += index_and_largest_score.score; + +// retained.insert({q_block_idx, index_and_largest_score.k_block_idx}); +// } +// } + +// // Enforce causal + +// auto it = retained.begin(); + +// while (it != retained.end()) { +// size_t q = it->first; -// for (size_t q = 0; q < num_q_blocks; ++q) { -// std::cout << std::setw(6) << ("Q" + std::to_string(q)) << " "; +// size_t k = it->second; + +// if (k >= current_index && (k - current_index) > q) { +// it = retained.erase(it); + +// } else { +// ++it; +// } +// } +// } + +// return retval; +// } + +// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, +// const Shape& blocked_attention_scores_shape) { +// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); + +// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + +// size_t num_heads = blocked_attention_scores_shape[0]; +// size_t q_block_num = blocked_attention_scores_shape[1]; +// size_t k_block_num = blocked_attention_scores_shape[2]; + +// // keep the same current_index computation as original C++ (matches Python caller behavior) +// size_t current_index = k_block_num - q_block_num; + +// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { +// auto& retained = retval[head_idx]; + +// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { +// // --- 1) 读一行(q_block_idx)并计算 row_sum +// std::vector row(k_block_num); // double row_sum = 0.0; -// for (size_t k = 0; k < num_k_blocks; ++k) { -// size_t idx = h * (num_q_blocks * num_k_blocks) + q * num_k_blocks + k; -// double v = static_cast(static_cast(*(data + idx))); +// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { +// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// double v = static_cast(blocked_attention_scores_data[offset]); +// if (std::isnan(v) || std::isinf(v)) +// v = 0.0; +// row[k_block_idx] = v; // row_sum += v; -// std::cout << std::setw(12) << std::fixed << std::setprecision(6) << v; // } -// std::cout << " sum=" << std::fixed << std::setprecision(6) << row_sum << "\n"; + +// double required_sum = m_threshold * row_sum; + +// // --- 2) 构造 forced mask(与 Python 中 mask 一致:k==0 与 diagonal_k) +// std::vector forced(k_block_num, 0); +// forced[0] = 1; +// size_t diagonal_k = current_index + q_block_idx; +// if (diagonal_k < k_block_num) +// forced[diagonal_k] = 1; + +// // --- 3) 计算 forced_sum(就是 torch.where(mask, input_tensor, 0).sum(...)) +// double forced_sum = 0.0; +// for (size_t k = 0; k < k_block_num; ++k) +// if (forced[k]) +// forced_sum += row[k]; + +// // --- 4) 构造 other_values = masked_fill(mask, 0) 并做降序排序(保留索引) +// std::vector> other_pairs; // (value, k_idx) +// other_pairs.reserve(k_block_num); +// for (size_t k = 0; k < k_block_num; ++k) { +// double val = forced[k] ? 0.0 : row[k]; +// other_pairs.emplace_back(val, k); +// } +// std::sort(other_pairs.begin(), other_pairs.end(), [](const auto& a, const auto& b) { +// return a.first > b.first; +// }); + +// // --- 5) 按 Python: 构造 sorted_values_final = [0, forced_sum, other_pairs[0..-3]] (即 sorted_values[:-2]) +// // 这样 final length == k_block_num(相同长度) +// std::vector sorted_values_cat; +// sorted_values_cat.reserve(k_block_num); +// sorted_values_cat.push_back(0.0); +// sorted_values_cat.push_back(forced_sum); +// size_t take = 0; +// if (k_block_num >= 2) { +// // other_pairs.size() == k_block_num +// // we need to append other_pairs[0 .. k_block_num-3] => count = k_block_num - 2 +// // but slice is other_pairs[:-2] -> indices [0 .. k_block_num-3] (count k_block_num-2) +// take = (k_block_num >= 2) ? (k_block_num - 2) : 0; +// } +// for (size_t i = 0; i < take; ++i) { +// sorted_values_cat.push_back(other_pairs[i].first); +// } +// // safety: if for some reason sizes mismatch, pad zeros to reach length k_block_num +// while (sorted_values_cat.size() < k_block_num) +// sorted_values_cat.push_back(0.0); + +// // --- 6) 构造 index_order == argsort(descending) of where(mask, BIG*(1+row), row) +// std::vector> index_pairs; +// index_pairs.reserve(k_block_num); +// const double BIG = 100000.0; // mirrors Python 100000*(1 + input_tensor) +// for (size_t k = 0; k < k_block_num; ++k) { +// double key = forced[k] ? (BIG * (1.0 + row[k])) : row[k]; +// index_pairs.emplace_back(key, k); +// } +// std::sort(index_pairs.begin(), index_pairs.end(), [](const auto& a, const auto& b) { +// return a.first > b.first; +// }); + +// // --- 7) 计算 cumulative_sum_without_self == cumsum( [0] + sorted_values_cat[0:-1] ) +// // 即 cumsum_before[pos] = sum(sorted_values_cat[0 .. pos-1]) +// std::vector cumsum_before(k_block_num, 0.0); +// double acc = 0.0; +// for (size_t pos = 0; pos < k_block_num; ++pos) { +// cumsum_before[pos] = acc; +// acc += sorted_values_cat[pos]; +// } + +// // --- 8) 构造 index 掩码: index[pos] = index_pairs[pos].second if cumsum_before[pos] < required_sum else 0 +// // 然后把 index[pos] 对应的 k 插入 retained(等价于 python 的 fancy assignment) +// // 先强制包含 (align with original C++) +// retained.insert({q_block_idx, 0}); +// if (diagonal_k < k_block_num) +// retained.insert({q_block_idx, diagonal_k}); + +// for (size_t pos = 0; pos < k_block_num; ++pos) { +// if (cumsum_before[pos] < required_sum) { +// size_t sel_k = index_pairs[pos].second; +// retained.insert({q_block_idx, sel_k}); +// } else { +// // python uses 0 where mask false; but we already inserted 0 above +// } +// } + +// // --- Note: we intentionally do NOT add any ad-hoc "neighbor extension" here. +// // The above faithfully reproduces Python's selection (including the "[:-2]" trimming). +// // Debug printing (commented): +// if (head_idx == 0 && (q_block_idx == 6 || q_block_idx == 7)) { +// std::cout << "[DBG] q=" << q_block_idx +// << " row_sum=" << row_sum +// << " required=" << required_sum +// << " forced_sum=" << forced_sum +// << " cumsum_before(last)=" << cumsum_before.back() +// << " retained_count=" << retained.size() << std::endl; +// std::cout << " index_order: "; +// for (size_t i = 0; i < index_pairs.size(); ++i) std::cout << index_pairs[i].second << " "; +// std::cout << std::endl; +// std::cout << " sorted_values_cat: "; +// for (size_t i = 0; i < sorted_values_cat.size(); ++i) std::cout << sorted_values_cat[i] << " "; +// std::cout << std::endl; +// } +// } // q_block loop + +// // --- Enforce causal (keep original style/condition) +// auto it = retained.begin(); +// while (it != retained.end()) { +// size_t q = it->first; +// size_t k = it->second; +// if (k >= current_index && (k - current_index) > q) { +// it = retained.erase(it); +// } else { +// ++it; +// } // } -// std::cout << std::flush; -// } +// } // head loop + +// return retval; // } +// template // XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( -// const T* blocked_attention_scores_data, -// const Shape& blocked_attention_scores_shape) { -// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); -// // [num_heads, num_blocks_in_query, num_blocks_in_key] +// const T* blocked_attention_scores_data, +// const Shape& blocked_attention_scores_shape) { +// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); + // auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - -// struct IndexAndScore { -// XAttentionBlockIndex idx; -// T score; -// }; - -// const size_t num_heads = blocked_attention_scores_shape[0]; -// const size_t num_q_blocks = blocked_attention_scores_shape[1]; -// const size_t num_k_blocks = blocked_attention_scores_shape[2]; -// print_blocked_attention_scores(blocked_attention_scores_data, num_heads, num_q_blocks, num_k_blocks); + +// size_t num_heads = blocked_attention_scores_shape[0]; +// size_t q_block_num = blocked_attention_scores_shape[1]; +// size_t k_block_num = blocked_attention_scores_shape[2]; + +// // 当前索引保持与原始 C++ 一致,匹配 Python caller +// size_t current_index = k_block_num - q_block_num; + +// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { +// auto& retained = retval[head_idx]; + +// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { +// // --- 1) 读一行(q_block_idx)并计算 row_sum +// std::vector row(k_block_num); +// double row_sum = 0.0; +// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { +// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// double v = static_cast(blocked_attention_scores_data[offset]); +// if (std::isnan(v) || std::isinf(v)) +// v = 0.0; +// row[k_block_idx] = v; +// row_sum += v; +// } + +// double required_sum = m_threshold * row_sum; + +// // --- 2) 构造 forced mask(k==0 与 diagonal_k) +// std::vector forced(k_block_num, 0); +// forced[0] = 1; +// size_t diagonal_k = current_index + q_block_idx; +// if (diagonal_k < k_block_num) +// forced[diagonal_k] = 1; + +// // --- 3) 计算 forced_sum +// double forced_sum = 0.0; +// for (size_t k = 0; k < k_block_num; ++k) +// if (forced[k]) +// forced_sum += row[k]; + +// // --- 4) 构造 other_values = masked_fill(mask,0) 并降序排序 +// std::vector> other_pairs; +// other_pairs.reserve(k_block_num); +// for (size_t k = 0; k < k_block_num; ++k) { +// double val = forced[k] ? 0.0 : row[k]; +// other_pairs.emplace_back(val, k); +// } +// std::sort(other_pairs.begin(), other_pairs.end(), [](const auto& a, const auto& b) { +// return a.first > b.first; +// }); + +// // --- 5) 构造 sorted_values_cat +// std::vector sorted_values_cat; +// sorted_values_cat.reserve(k_block_num); +// sorted_values_cat.push_back(0.0); +// sorted_values_cat.push_back(forced_sum); +// size_t take = (k_block_num >= 2) ? (k_block_num - 2) : 0; +// for (size_t i = 0; i < take; ++i) { +// sorted_values_cat.push_back(other_pairs[i].first); +// } +// while (sorted_values_cat.size() < k_block_num) +// sorted_values_cat.push_back(0.0); + +// // --- 6) 构造 index_order +// std::vector> index_pairs; +// index_pairs.reserve(k_block_num); +// const double BIG = 100000.0; +// for (size_t k = 0; k < k_block_num; ++k) { +// double key = forced[k] ? (BIG * (1.0 + row[k])) : row[k]; +// index_pairs.emplace_back(key, k); +// } +// std::sort(index_pairs.begin(), index_pairs.end(), [](const auto& a, const auto& b) { +// return a.first > b.first; +// }); + +// // --- 7) 构造 cumulative_sum_without_self +// std::vector cumsum_before(k_block_num, 0.0); +// double acc = 0.0; +// for (size_t pos = 0; pos < k_block_num; ++pos) { +// cumsum_before[pos] = acc; +// acc += sorted_values_cat[pos]; +// } + +// // // --- 8) 累加保留逻辑,严格对应 Python +// // retained.insert({q_block_idx, 0}); +// // if (diagonal_k < k_block_num) +// // retained.insert({q_block_idx, diagonal_k}); + +// // for (size_t pos = 0; pos < k_block_num; ++pos) { +// // if (cumsum_before[pos] < required_sum) { +// // size_t sel_k = index_pairs[pos].second; +// // retained.insert({q_block_idx, sel_k}); +// // } else { +// // break; // <-- 关键修改,停止累加,避免多保留 (7,6) +// // } +// // } + +// // --- 8) 累加保留逻辑,严格对应 Python +// retained.insert({q_block_idx, 0}); +// if (diagonal_k < k_block_num) +// retained.insert({q_block_idx, diagonal_k}); + +// for (size_t pos = 0; pos < k_block_num; ++pos) { +// size_t sel_k = index_pairs[pos].second; +// if (!forced[sel_k] && cumsum_before[pos] >= required_sum) { +// // Python 对应 torch.where(index_mask, index, 0) +// continue; // 不保留非强制位置 +// } +// retained.insert({q_block_idx, sel_k}); +// } + + + +// // --- debug 打印(可注释) +// /* +// if (head_idx == 0 && (q_block_idx == 6 || q_block_idx == 7)) { +// std::cout << "[DBG] q=" << q_block_idx +// << " row_sum=" << row_sum +// << " required=" << required_sum +// << " forced_sum=" << forced_sum +// << " cumsum_before(last)=" << cumsum_before.back() +// << " retained_count=" << retained.size() << std::endl; +// std::cout << " index_order: "; +// for (size_t i = 0; i < index_pairs.size(); ++i) std::cout << index_pairs[i].second << " "; +// std::cout << std::endl; +// std::cout << " sorted_values_cat: "; +// for (size_t i = 0; i < sorted_values_cat.size(); ++i) std::cout << sorted_values_cat[i] << " "; +// std::cout << std::endl; +// } +// */ +// } + +// // --- Enforce causal +// auto it = retained.begin(); +// while (it != retained.end()) { +// size_t q = it->first; +// size_t k = it->second; +// if (k >= current_index && (k - current_index) > q) { +// it = retained.erase(it); +// } else { +// ++it; +// } +// } +// } + +// return retval; +// } + +// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, +// const Shape& blocked_attention_scores_shape) { +// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); + +// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + +// size_t num_heads = blocked_attention_scores_shape[0]; +// size_t q_block_num = blocked_attention_scores_shape[1]; +// size_t k_block_num = blocked_attention_scores_shape[2]; + +// // 与 Python 对齐 +// size_t current_index = k_block_num - q_block_num; + +// for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { +// auto& retained = retval[head_idx]; + +// for (size_t q_block_idx = 0; q_block_idx < q_block_num; ++q_block_idx) { +// // --- 1) 读取一行 +// std::vector row(k_block_num); +// double row_sum = 0.0; +// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { +// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// double v = static_cast(blocked_attention_scores_data[offset]); +// if (std::isnan(v) || std::isinf(v)) v = 0.0; +// row[k_block_idx] = v; +// row_sum += v; +// } + +// double required_sum = m_threshold * row_sum; + +// // --- 2) 强制保留位置 +// std::vector forced(k_block_num, 0); +// forced[0] = 1; +// size_t diagonal_k = current_index + q_block_idx; +// if (diagonal_k < k_block_num) forced[diagonal_k] = 1; + +// double forced_sum = 0.0; +// for (size_t k = 0; k < k_block_num; ++k) +// if (forced[k]) forced_sum += row[k]; + +// // --- 3) 其他值排序 +// std::vector> other_pairs; // (value, k_idx) +// for (size_t k = 0; k < k_block_num; ++k) +// other_pairs.emplace_back(forced[k] ? 0.0 : row[k], k); +// std::sort(other_pairs.begin(), other_pairs.end(), [](const auto& a, const auto& b) { +// return a.first > b.first; +// }); + +// // --- 4) 构造 sorted_values_cat +// std::vector sorted_values_cat; +// sorted_values_cat.push_back(0.0); +// sorted_values_cat.push_back(forced_sum); +// size_t take = k_block_num >= 2 ? k_block_num - 2 : 0; +// for (size_t i = 0; i < take; ++i) sorted_values_cat.push_back(other_pairs[i].first); +// while (sorted_values_cat.size() < k_block_num) sorted_values_cat.push_back(0.0); + +// // --- 5) 构造 index_pairs (argsort desc) +// std::vector> index_pairs; +// const double BIG = 100000.0; +// for (size_t k = 0; k < k_block_num; ++k) +// index_pairs.emplace_back(forced[k] ? (BIG * (1.0 + row[k])) : row[k], k); +// std::sort(index_pairs.begin(), index_pairs.end(), [](const auto& a, const auto& b) { +// return a.first > b.first; +// }); + +// // --- 6) cumsum_before +// std::vector cumsum_before(k_block_num, 0.0); +// double acc = 0.0; +// for (size_t pos = 0; pos < k_block_num; ++pos) { +// cumsum_before[pos] = acc; +// acc += sorted_values_cat[pos]; +// } + +// // --- 7) 强制保留 +// retained.insert({q_block_idx, 0}); +// if (diagonal_k < k_block_num) retained.insert({q_block_idx, diagonal_k}); + +// // --- 8) 按 Python 逻辑选择 +// for (size_t pos = 0; pos < k_block_num; ++pos) { +// if (cumsum_before[pos] < required_sum) { +// size_t sel_k = index_pairs[pos].second; +// retained.insert({q_block_idx, sel_k}); +// } +// } + +// // --- 9) 完整 debug 打印 +// std::cout << "[DBG] head=" << head_idx << " q=" << q_block_idx +// << " row_sum=" << row_sum +// << " required=" << required_sum +// << " forced_sum=" << forced_sum +// << " cumsum_before(last)=" << cumsum_before.back() +// << " retained_count=" << retained.size() << std::endl; + +// std::cout << " row: "; +// for (auto v : row) std::cout << v << " "; +// std::cout << std::endl; + +// std::cout << " forced: "; +// for (auto f : forced) std::cout << (int)f << " "; +// std::cout << std::endl; + +// std::cout << " other_pairs: "; +// for (auto& p : other_pairs) std::cout << "(" << p.first << "," << p.second << ") "; +// std::cout << std::endl; + +// std::cout << " sorted_values_cat: "; +// for (auto v : sorted_values_cat) std::cout << v << " "; +// std::cout << std::endl; + +// std::cout << " index_pairs: "; +// for (auto& p : index_pairs) std::cout << "(" << p.first << "," << p.second << ") "; +// std::cout << std::endl; + +// std::cout << " cumsum_before: "; +// for (auto v : cumsum_before) std::cout << v << " "; +// std::cout << std::endl; + +// std::cout << " retained before causal: "; +// for (auto& p : retained) std::cout << "(" << p.first << "," << p.second << ") "; +// std::cout << std::endl; +// } // q_block loop + +// // --- 10) enforce causal +// auto it = retained.begin(); +// while (it != retained.end()) { +// size_t q = it->first; +// size_t k = it->second; +// if (k >= current_index && (k - current_index) > q) +// it = retained.erase(it); +// else +// ++it; +// } + +// // --- 11) 打印 causal 后 retained +// std::cout << "[DBG] head=" << head_idx << " retained after causal: "; +// for (auto& p : retained) std::cout << "(" << p.first << "," << p.second << ") "; +// std::cout << std::endl; +// } // head loop + +// return retval; +// } + + +// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, +// const Shape& blocked_attention_scores_shape) { +// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); + +// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + +// size_t num_heads = blocked_attention_scores_shape[0]; +// size_t q_block_num = blocked_attention_scores_shape[1]; +// size_t k_block_num = blocked_attention_scores_shape[2]; + +// size_t current_index = k_block_num - q_block_num; // Python caller behavior + +// const double BIG = 100000.0; // for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { -// size_t head_offset = head_idx * num_q_blocks * num_k_blocks; +// auto& retained = retval[head_idx]; + +// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { +// // --- 1) row +// std::vector row(k_block_num); +// double row_sum = 0.0; +// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { +// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// double v = static_cast(blocked_attention_scores_data[offset]); +// if (std::isnan(v) || std::isinf(v)) v = 0.0; +// row[k_block_idx] = v; +// row_sum += v; +// } +// double required_sum = m_threshold * row_sum; -// for (size_t q_block_idx = 0; q_block_idx < num_q_blocks; q_block_idx++) { -// std::vector indices_and_scores; -// indices_and_scores.reserve(num_k_blocks); +// // --- 2) forced mask +// std::vector forced(k_block_num, 0); +// forced[0] = 1; +// size_t diagonal_k = current_index + q_block_idx; +// if (diagonal_k < k_block_num) +// forced[diagonal_k] = 1; -// double total_sum = 0.0; +// // --- 3) forced sum +// double forced_sum = 0.0; +// for (size_t k = 0; k < k_block_num; ++k) +// if (forced[k]) forced_sum += row[k]; -// for (size_t k_block_idx = 0; k_block_idx < num_k_blocks; k_block_idx++) { -// size_t target_offset = head_offset + q_block_idx * num_k_blocks + k_block_idx; -// T current_score = *(blocked_attention_scores_data + target_offset); -// indices_and_scores.push_back({{q_block_idx, k_block_idx}, current_score}); -// total_sum += current_score; +// // --- 4) other values +// std::vector> other_pairs; // value, k +// for (size_t k = 0; k < k_block_num; ++k) { +// if (!forced[k]) other_pairs.emplace_back(row[k], k); // } +// std::sort(other_pairs.begin(), other_pairs.end(), +// [](const auto& a, const auto& b) { return a.first > b.first; }); -// double required_sum = m_threshold * total_sum; +// // --- 5) sorted_values_cat +// std::vector sorted_values_cat; +// sorted_values_cat.push_back(0.0); +// sorted_values_cat.push_back(forced_sum); +// size_t take_count = (other_pairs.size() >= 2) ? other_pairs.size() - 2 : other_pairs.size(); +// for (size_t i = 0; i < take_count; ++i) sorted_values_cat.push_back(other_pairs[i].first); +// while (sorted_values_cat.size() < k_block_num) sorted_values_cat.push_back(0.0); -// std::sort(indices_and_scores.begin(), indices_and_scores.end(), -// [](const IndexAndScore& a, const IndexAndScore& b) { -// return a.score > b.score; -// }); +// // --- 6) index pairs (argsort) +// std::vector> index_pairs; +// for (size_t k = 0; k < k_block_num; ++k) { +// double key = forced[k] ? BIG * (1.0 + row[k]) : row[k]; +// index_pairs.emplace_back(key, k); +// } +// std::sort(index_pairs.begin(), index_pairs.end(), +// [](const auto& a, const auto& b) { return a.first > b.first; }); -// std::vector shifted_cumsum(num_k_blocks, 0.0); -// for (size_t i = 1; i < num_k_blocks; i++) { -// shifted_cumsum[i] = shifted_cumsum[i - 1] + indices_and_scores[i - 1].score; +// // --- 7) cumsum_before +// std::vector cumsum_before(k_block_num, 0.0); +// double acc = 0.0; +// for (size_t pos = 0; pos < k_block_num; ++pos) { +// cumsum_before[pos] = acc; +// acc += sorted_values_cat[pos]; // } -// for (size_t i = 0; i < num_k_blocks; i++) { -// if (shifted_cumsum[i] < required_sum) { -// retval[head_idx].insert(indices_and_scores[i].idx); +// // --- 8) insert into retained +// // force include 0 and diagonal +// retained.insert({q_block_idx, 0}); +// if (diagonal_k < k_block_num) retained.insert({q_block_idx, diagonal_k}); + +// for (size_t pos = 0; pos < k_block_num; ++pos) { +// if (cumsum_before[pos] < required_sum) { +// size_t sel_k = index_pairs[pos].second; +// retained.insert({q_block_idx, sel_k}); // } // } + +// // --- debug print +// std::cout << "[DBG] head=" << head_idx << " q=" << q_block_idx +// << " row_sum=" << row_sum +// << " required=" << required_sum +// << " forced_sum=" << forced_sum +// << " cumsum_before(last)=" << cumsum_before.back() +// << " retained_count=" << retained.size() << "\n"; +// std::cout << " row: "; +// for (auto v : row) std::cout << v << " "; +// std::cout << "\n forced: "; +// for (auto f : forced) std::cout << int(f) << " "; +// std::cout << "\n other_pairs: "; +// for (auto& p : other_pairs) std::cout << "(" << p.first << "," << p.second << ") "; +// std::cout << "\n sorted_values_cat: "; +// for (auto v : sorted_values_cat) std::cout << v << " "; +// std::cout << "\n index_pairs: "; +// for (auto& p : index_pairs) std::cout << "(" << p.first << "," << p.second << ") "; +// std::cout << "\n cumsum_before: "; +// for (auto v : cumsum_before) std::cout << v << " "; +// std::cout << "\n retained before causal: "; +// for (auto& x : retained) std::cout << "(" << x.first << "," << x.second << ") "; +// std::cout << "\n"; +// } + +// // --- 9) causal mask +// auto it = retained.begin(); +// while (it != retained.end()) { +// size_t q = it->first; +// size_t k = it->second; +// if (k >= current_index && (k - current_index) > q) { +// it = retained.erase(it); +// } else { +// ++it; +// } +// } + +// // --- debug retained after causal +// std::cout << "[DBG] head=" << head_idx << " retained after causal: "; +// for (auto& x : retained) std::cout << "(" << x.first << "," << x.second << ") "; +// std::cout << "\n"; +// } + +// return retval; +// } + +void print_blocked_attention_scores(const T* blocked_attention_scores_data, + size_t num_heads, + size_t q_block_num, + size_t k_block_num) { + std::cout << "=== blocked_attention_scores_data ===\n"; + for (size_t h = 0; h < num_heads; ++h) { + std::cout << "Head " << h << ":\n"; + for (size_t q = 0; q < q_block_num; ++q) { + std::cout << " q_block " << q << ": "; + for (size_t k = 0; k < k_block_num; ++k) { + size_t offset = h * q_block_num * k_block_num + q * k_block_num + k; + std::cout << std::fixed << std::setprecision(6) + << blocked_attention_scores_data[offset] << " "; + } + std::cout << "\n"; + } + std::cout << "\n"; + } +} + +void print_retained_blocks(const XAttentionRetainedBlockIndicesForAllHeads& retained_blocks) { + for (size_t head = 0; head < retained_blocks.size(); ++head) { + std::cout << "[Head " << head << "] retained blocks: "; + for (const auto& p : retained_blocks[head]) { + std::cout << "(" << p.first << "," << p.second << ") "; + } + std::cout << std::endl; + } +} + +void print_scores(const std::vector>& scores) { + std::cout << "[Scores] "; + for (const auto& p : scores) { + std::cout << "(" << p.first << ", " << p.second << ") "; + } + std::cout << std::endl; +} + + + +// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( +// T* blocked_attention_scores_data, +// const Shape& blocked_attention_scores_shape) { +// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); + +// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + +// size_t num_heads = blocked_attention_scores_shape[0]; +// size_t q_block_num = blocked_attention_scores_shape[1]; +// size_t k_block_num = blocked_attention_scores_shape[2]; + +// print_blocked_attention_scores(blocked_attention_scores_data, +// num_heads, q_block_num, k_block_num); + + +// float blocked_attention_scores_values[q_block_num * k_block_num] = { +// 2.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, +// 1.1399f, 0.8601f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, +// 0.5426f, 0.8147f, 0.6427f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, +// 0.4169f, 0.5852f, 0.6589f, 0.3390f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, +// 0.5131f, 0.4026f, 0.4603f, 0.3615f, 0.2625f, 0.0000f, 0.0000f, 0.0000f, +// 0.3882f, 0.3218f, 0.3278f, 0.3583f, 0.3449f, 0.2589f, 0.0000f, 0.0000f, +// 0.3030f, 0.3146f, 0.2382f, 0.3002f, 0.2992f, 0.3479f, 0.1969f, 0.0000f, +// 0.2431f, 0.3503f, 0.3054f, 0.2146f, 0.2261f, 0.2692f, 0.1847f, 0.2065f +// }; + +// // 分配可写的 ov::float16 buffer +// // ov::float16* blocked_attention_scores_data = new ov::float16[num_heads * q_block_num * k_block_num]; + +// // 逐元素赋值 +// for (int i = 0; i < 64; ++i) { +// blocked_attention_scores_data[i] = ov::float16(blocked_attention_scores_values[i]); +// } + +// print_blocked_attention_scores(blocked_attention_scores_data, +// num_heads, q_block_num, k_block_num); + +// // ✅ Python 中没有 current_index 偏移的逻辑 +// // 原逻辑引入 offset 导致 diagonal 错位 +// // 如果确实需要 offset,可通过参数控制,但这里保持与 Python 一致 +// // size_t current_index = 0; + +// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { +// auto& retained = retval[head_idx]; + +// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { +// std::cout << "**************************\n"; +// // 1️⃣ 累加整行分数 +// double row_sum = 0.0; +// for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { +// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// row_sum += static_cast(blocked_attention_scores_data[offset]); +// } + +// double required_sum = m_threshold * row_sum; +// std::cout << "required_sum: " << required_sum << std::endl; +// double cumsum = 0.0; + +// // // 2️⃣ 强制保留 diagonal 块 +// // size_t diagonal_k = q_block_idx; +// // size_t offset_diag = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + diagonal_k; +// // double diag_score = static_cast(blocked_attention_scores_data[offset_diag]); +// // std::cout << "diag_score: " << diag_score << std::endl; +// // cumsum += diag_score; +// // retained.insert({q_block_idx, diagonal_k}); + +// // print_retained_blocks(retval); + +// // // 3️⃣ 收集所有候选块 +// // std::vector> scores; +// // scores.reserve(k_block_num); +// // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { +// // if (k_block_idx == diagonal_k) +// // continue; +// // if (k_block_idx == 0) continue; +// // size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// // scores.emplace_back(static_cast(blocked_attention_scores_data[offset]), k_block_idx); +// // } + +// // print_scores(scores); + +// // // 4️⃣ 降序排序(高分优先) +// // std::sort(scores.begin(), scores.end(), +// // [](const auto& a, const auto& b) { return a.first > b.first; }); + +// // // 5️⃣ 从高到低选取直到累积超过阈值 +// // for (auto& [score, k_block_idx] : scores) { +// // if (cumsum >= required_sum) +// // break; +// // cumsum += score; +// // retained.insert({q_block_idx, k_block_idx}); +// // } + + +// // 2️⃣ 强制保留 diagonal 块 +// size_t diagonal_k = q_block_idx; +// size_t offset_diag = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + diagonal_k; +// double diag_score = static_cast(blocked_attention_scores_data[offset_diag]); +// cumsum += diag_score; +// retained.insert({q_block_idx, diagonal_k}); + +// // 2️⃣.1️⃣ 额外:强制保留首列块 (k=0),与 Python mask[:, :, :, 0] = 1 一致 +// if (k_block_num > 0 && q_block_idx != 0) { +// size_t offset_first = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + 0; +// double first_col_score = static_cast(blocked_attention_scores_data[offset_first]); +// cumsum += first_col_score; +// retained.insert({q_block_idx, 0}); +// } + +// // 3️⃣ 收集其他候选块(去掉 diagonal 和首列) +// std::vector> scores; +// scores.reserve(k_block_num); +// for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { +// if (k_block_idx == diagonal_k || k_block_idx == 0) +// continue; +// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; +// scores.emplace_back(static_cast(blocked_attention_scores_data[offset]), k_block_idx); +// } + +// // 4️⃣ 降序排序(高分优先) +// std::sort(scores.begin(), scores.end(), +// [](const auto& a, const auto& b) { return a.first > b.first; }); + +// // 5️⃣ 从高到低选取直到累积超过阈值 +// for (auto& [score, k_block_idx] : scores) { +// if (cumsum >= required_sum) +// break; +// cumsum += score; +// retained.insert({q_block_idx, k_block_idx}); +// } + + +// // 6️⃣ 保证左侧(k <= q)邻域不被裁掉 +// // (Python 行为是保留对角线及左侧邻近块) +// for (int s = 1; s <= 2; s++) { // stride=2 可根据外部参数替换 +// if (q_block_idx >= static_cast(s)) +// retained.insert({q_block_idx, q_block_idx - s}); +// } + +// // 7️⃣ 保证对角块右邻域(但受 causal 约束) +// for (int s = 1; s <= 2; s++) { +// size_t right = q_block_idx + s; +// if (right < k_block_num) +// retained.insert({q_block_idx, right}); +// } + +// // 调试打印(默认注释) +// // std::cout << "[Head " << head_idx << "] Q=" << q_block_idx +// // << " required_sum=" << required_sum << " cumsum=" << cumsum +// // << " diag_score=" << diag_score << " retained=" << retained.size() +// // << std::endl; // } + +// // 8️⃣ 修正 causal mask(与 Python 一致:禁止未来块) +// auto it = retained.begin(); +// while (it != retained.end()) { +// size_t q = it->first; +// size_t k = it->second; +// if (k > q) { // ✅ Python 中严格排除未来块 +// it = retained.erase(it); +// } else { +// ++it; +// } +// } + +// // 调试打印(默认注释) +// // std::cout << "Head " << head_idx << " selected blocks:"; +// // for (auto [a, b] : retained) +// // std::cout << " (" << a << "," << b << ")"; +// // std::cout << std::endl; // } // return retval; // } - XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, - const Shape& blocked_attention_scores_shape) { - OPENVINO_ASSERT(blocked_attention_scores_shape.size() == - 3); // [num_heads, num_blocks_in_query, num_blocks_in_key] +XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( + T* blocked_attention_scores_data, + const Shape& blocked_attention_scores_shape) { + + OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3, + "Expected shape [num_heads, q_block_num, k_block_num]"); - auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); + size_t num_heads = blocked_attention_scores_shape[0]; + size_t q_block_num = blocked_attention_scores_shape[1]; + size_t k_block_num = blocked_attention_scores_shape[2]; - struct IndexAndScore { - XAttentionBlockIndex idx; - T score; - bool operator<(const IndexAndScore& rhs) const { - return score < rhs.score; + // float blocked_attention_scores_values[q_block_num * k_block_num] = { + // 2.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, + // 1.1399f, 0.8601f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, + // 0.5426f, 0.8147f, 0.6427f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, + // 0.4169f, 0.5852f, 0.6589f, 0.3390f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, + // 0.5131f, 0.4026f, 0.4603f, 0.3615f, 0.2625f, 0.0000f, 0.0000f, 0.0000f, + // 0.3882f, 0.3218f, 0.3278f, 0.3583f, 0.3449f, 0.2589f, 0.0000f, 0.0000f, + // 0.3030f, 0.3146f, 0.2382f, 0.3002f, 0.2992f, 0.3479f, 0.1969f, 0.0000f, + // 0.2431f, 0.3503f, 0.3054f, 0.2146f, 0.2261f, 0.2692f, 0.1847f, 0.2065f + // }; + + // // 分配可写的 ov::float16 buffer + // // ov::float16* blocked_attention_scores_data = new ov::float16[num_heads * q_block_num * k_block_num]; + + // // 逐元素赋值 + // for (int i = 0; i < 64; ++i) { + // blocked_attention_scores_data[i] = ov::float16(blocked_attention_scores_values[i]); + // } + + // std::vector blocked_attention_scores_f32(num_heads * q_block_num * k_block_num); + // for (size_t i = 0; i < blocked_attention_scores_f32.size(); ++i) { + // blocked_attention_scores_f32[i] = static_cast(blocked_attention_scores_data[i]); + // } + + // print_blocked_attention_scores(blocked_attention_scores_data, + // num_heads, q_block_num, k_block_num); + + // 返回结果,每个 head 一个 set 存储 (q_block_idx, k_block_idx) + XAttentionRetainedBlockIndicesForAllHeads retval(num_heads); + + // 临时 mask 矩阵,用于模拟 Python mask + std::vector>> mask( + num_heads, std::vector>( + q_block_num, std::vector(k_block_num, false))); + + for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { + for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { + // Step0: diagonal 保留 + size_t diagonal_k = q_block_idx; + if (diagonal_k < k_block_num) { + mask[head_idx][q_block_idx][diagonal_k] = true; + } + // Step1: 首列保留 + mask[head_idx][q_block_idx][0] = true; + + // Step2: 构建 other_values(masked_fill) + std::vector> other_values; + for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { + if (mask[head_idx][q_block_idx][k_block_idx]) + continue; + size_t offset = head_idx * q_block_num * k_block_num + + q_block_idx * k_block_num + + k_block_idx; + other_values.emplace_back(static_cast(blocked_attention_scores_data[offset]), k_block_idx); + } + + // // Step4: 打印 other_values + // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] other_values:\n"; + // for (auto& [score, k_block_idx] : other_values) { + // std::cout << "(" << k_block_idx << ", " << score << ") "; + // } + // std::cout << std::endl; + + // Step3: 对 other_values 降序排序 + std::sort(other_values.begin(), other_values.end(), + [](const auto& a, const auto& b) { return a.first > b.first; }); + + // Step4: 构建 cumulative_sum_without_self,cat([0, diagonal_sum, sorted_values[:-1]]) + std::vector sorted_scores; + sorted_scores.push_back(0.0); // 前置0 + // diagonal + 首列分数 + size_t offset_diag = head_idx * q_block_num * k_block_num + + q_block_idx * k_block_num + + diagonal_k; + float diag_score = static_cast(blocked_attention_scores_data[offset_diag]); + float first_col_score = 0.0; + if (diagonal_k != 0) { + size_t offset_first = head_idx * q_block_num * k_block_num + + q_block_idx * k_block_num + + 0; + first_col_score = static_cast(blocked_attention_scores_data[offset_first]); + } + std::cout << diag_score << " " << diag_score << " " << first_col_score << " " << diag_score + first_col_score << std::endl; + sorted_scores.push_back(diag_score + first_col_score); + + // for (size_t i = 0; i + 1 < other_values.size(); i++) { + // sorted_scores.push_back(other_values[i].first); + // } + for (auto& p : other_values) { + sorted_scores.push_back(p.first); + } + if (q_block_idx == 0) { + sorted_scores.pop_back(); + } + // // Step4.1: 打印 sorted_scores + // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] sorted_scores: "; + // for (size_t i = 0; i < sorted_scores.size(); i++) { + // std::cout << sorted_scores[i] << " "; + // } + // std::cout << std::endl; + + + + // Step5: 计算 cumsum_without_self: cumsum of right-shifted sorted_scores + std::vector cumsum_without_self(sorted_scores.size(), 0.0); + float running = 0.0; + for (size_t i = 0; i < sorted_scores.size(); ++i) { + cumsum_without_self[i] = running; // 等价于 Python 的 cat([0, ...]) then cumsum, i.e. previous sum + running += sorted_scores[i]; + } + + // // 打印 cumsum_without_self(调试用) + // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] cumsum: "; + // for (size_t i = 0; i < cumsum_without_self.size(); i++) { + // std::cout << cumsum_without_self[i] << " "; + // } + // std::cout << std::endl; + + // Step6: 生成 required_sum(基于整行) + size_t offset_row_start = head_idx * q_block_num * k_block_num + + q_block_idx * k_block_num; + float row_sum = 0.0; + for (size_t k = 0; k < k_block_num; k++) { + row_sum += static_cast(blocked_attention_scores_data[offset_row_start + k]); + } + float required_sum = row_sum * m_threshold; + std::cout << "required_sum: " << required_sum << std::endl; + + + // Step7: 构建 index_mask + std::vector index_mask(cumsum_without_self.size(), false); + for (size_t i = 0; i < cumsum_without_self.size(); i++) { + index_mask[i] = (cumsum_without_self[i] < required_sum); + } + + // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] index_mask: "; + // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { + // std::cout << (index_mask[head_idx][q_block_idx][k_block_idx] ? "1 " : "0 "); + // } + // std::cout << std::endl; + + + // Step8: 构建 index 向量(torch.where(index_mask, index, 0)) + std::vector index(index_mask.size(), 0); + for (size_t i = 0; i < index_mask.size(); i++) { + if (index_mask[i]) { + // 索引来源:sorted_scores[0], [1], ... 对应哪些 k_block? + // 前两个为 [0:padding], [1:diag+col0], 后续对应 other_values + if (i == 0) index[i] = 0; // dummy + else if (i == 1) index[i] = diagonal_k; + else if (i - 2 < other_values.size()) + index[i] = other_values[i - 2].second; + else + index[i] = 0; + } } - }; - - for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) { - size_t head_offset = head_idx * blocked_attention_scores_shape[1] * blocked_attention_scores_shape[2]; - std::priority_queue indices_and_scores_queue; - double total_sum = 0.0; - for (size_t q_block_idx = 0; q_block_idx < blocked_attention_scores_shape[1]; q_block_idx++) { - for (size_t k_block_idx = 0; k_block_idx < blocked_attention_scores_shape[2]; k_block_idx++) { - size_t target_offset = head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx; - T current_score = *(blocked_attention_scores_data + target_offset); - indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score}); - total_sum += current_score; + + // Step9: 模拟 Python mask[:, torch.arange(...), index] = True + // 即对每个 (head_idx, q_block_idx),将 index[i] 对应的 k_block 置 True + for (size_t i = 0; i < index.size(); i++) { + size_t k_block_idx = index[i]; + if (index_mask[i] && k_block_idx < k_block_num) { + mask[head_idx][q_block_idx][k_block_idx] = true; } } - double cumsum = 0.0; - double required_sum = m_threshold * total_sum; - while (cumsum < required_sum && !indices_and_scores_queue.empty()) { - auto index_and_largest_score = indices_and_scores_queue.top(); - indices_and_scores_queue.pop(); - cumsum += index_and_largest_score.score; - retval[head_idx].insert(index_and_largest_score.idx); + + + // 打印 cumsum_without_self(调试用) + std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] required_sum: " << required_sum << std::endl; + + // Step7: 根据 index_mask 更新 mask + // 注意:sorted_scores 带有两个前缀项,因此 other_values 对应的 sorted_scores 索引从 2 开始 + // but we must only iterate the number of other_values actually included in sorted_scores. + // size_t included_count = 0; + // if (sorted_scores.size() > 2) { + // included_count = sorted_scores.size() - 2; + // } else { + // included_count = 0; + // } + + + // // 🔹 Step10.1: 打印当前 head、q_block 的 mask + // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] mask: "; + // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { + // std::cout << (mask[head_idx][q_block_idx][k_block_idx] ? "1 " : "0 "); + // } + // std::cout << std::endl; + + // for (size_t i = 0; i < included_count; ++i) { + // size_t idx_in_sorted = 2 + i; + // // 安全检查(通常应该不越界) + // if (idx_in_sorted < cumsum_without_self.size()) { + // if (cumsum_without_self[idx_in_sorted] < required_sum) { + // size_t k_block_idx = other_values[i].second; + // mask[head_idx][q_block_idx][k_block_idx] = true; + // } + // } else { + // // 如果发生越界,输出调试信息(不抛异常以便继续调试) + // std::cerr << "Debug: idx_in_sorted out of range: " << idx_in_sorted + // << " cumsum_size=" << cumsum_without_self.size() + // << " other_values.size()=" << other_values.size() + // << " sorted_scores.size()=" << sorted_scores.size() << std::endl; + // } + // } + + // // Step8: 保留左侧邻域(stride=2) + // for (int s = 1; s <= 2; s++) { + // if (q_block_idx >= static_cast(s)) { + // std::cout << head_idx << " " << q_block_idx << " " << q_block_idx - s << std::endl; + // mask[head_idx][q_block_idx][q_block_idx - s] = true; + // } + // } + + // // Step9: 保留右侧邻域(受 causal 约束) + // for (int s = 1; s <= 2; s++) { + // size_t right = q_block_idx + s; + // if (right < k_block_num) { + // std::cout << head_idx << " " << q_block_idx << " " << right << std::endl; + // mask[head_idx][q_block_idx][right] = true; + // } + // } + + // // Step10: causal mask,删除未来块 + // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { + // if (k_block_idx > q_block_idx) + // mask[head_idx][q_block_idx][k_block_idx] = false; + // } + + // 🔹 Step10.1: 打印当前 head、q_block 的 mask + // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] mask: "; + // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { + // std::cout << (mask[head_idx][q_block_idx][k_block_idx] ? "1 " : "0 "); + // } + // std::cout << std::endl; + + // Step11: 收集 mask 为 true 的块到 retval + for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { + if (mask[head_idx][q_block_idx][k_block_idx]) + retval[head_idx].insert({q_block_idx, k_block_idx}); + } + } + } + + return retval; +} + +void print_attn_score_buf_with_shape(const std::shared_ptr& buf, + size_t num_heads, + size_t rows, // 实际 buf 的第2维长度 + size_t cols, // 实际 buf 的第3维长度 + size_t show_first_n_cols = 0) { // 0 表示显示全部 + std::cout << "=== Debug: attn_score_buf (shape = [" << num_heads << ", " << rows << ", " << cols << "]) ===\n"; + for (size_t h = 0; h < num_heads; ++h) { + std::cout << "Head " << h << ":\n"; + for (size_t r = 0; r < rows; ++r) { + std::cout << std::setw(3) << r << ": "; + size_t nonzero_count = 0; + size_t limit = (show_first_n_cols == 0) ? cols : std::min(cols, (size_t)show_first_n_cols); + for (size_t c = 0; c < limit; ++c) { + size_t idx = h * rows * cols + r * cols + c; + double v = static_cast(buf[idx]); + if (std::fabs(v) > 1e-12) ++nonzero_count; + std::cout << std::fixed << std::setprecision(6) << v << " "; + } + if (limit < cols) std::cout << "..."; + std::cout << " (nonzero=" << nonzero_count << ")\n"; + } + // 打印非零掩码行(帮助看 pattern) + std::cout << "Nonzero mask per row: "; + for (size_t r = 0; r < rows; ++r) { + size_t nonzero = 0; + for (size_t c = 0; c < cols; ++c) { + size_t idx = h * rows * cols + r * cols + c; + if (std::fabs(static_cast(buf[idx])) > 1e-12) { + nonzero = 1; + break; + } } + std::cout << nonzero; } - return retval; + std::cout << "\n\n"; } + std::cout << "=== End attn_score_buf ===\n"; +} + +void print_qk_buf(const std::shared_ptr& qk_buf, + size_t num_heads, + size_t q_block_num, + size_t k_block_num, + size_t show_first_n_cols = 0) { + std::cout << "\n=== Debug: qk_buf (shape = [" + << num_heads << ", " << q_block_num << ", " << k_block_num << "]) ===" + << std::endl; + + for (size_t h = 0; h < num_heads; ++h) { + std::cout << "Head " << h << ":\n"; + for (size_t q = 0; q < q_block_num; ++q) { + std::cout << std::setw(3) << q << ": "; + size_t limit = (show_first_n_cols == 0) + ? k_block_num + : std::min(k_block_num, (size_t)show_first_n_cols); + size_t nonzero_count = 0; + for (size_t k = 0; k < limit; ++k) { + size_t idx = h * q_block_num * k_block_num + q * k_block_num + k; + double val = static_cast(qk_buf[idx]); + if (std::fabs(val) > 1e-12) + ++nonzero_count; + std::cout << std::fixed << std::setprecision(6) << val << " "; + } + if (limit < k_block_num) + std::cout << "..."; + std::cout << " (nonzero=" << nonzero_count << ")\n"; + } + + // 打印每行是否含非零的简单掩码 + std::cout << "Nonzero mask per row: "; + for (size_t q = 0; q < q_block_num; ++q) { + bool nonzero = false; + for (size_t k = 0; k < k_block_num; ++k) { + size_t idx = h * q_block_num * k_block_num + q * k_block_num + k; + if (std::fabs(static_cast(qk_buf[idx])) > 1e-12) { + nonzero = true; + break; + } + } + std::cout << (nonzero ? "1" : "0"); + } + std::cout << "\n\n"; + } + + std::cout << "=== End of qk_buf ===\n" << std::endl; +} + +void assign_qk_buf(std::shared_ptr& qk_buf, + size_t num_heads, + size_t q_block_num, + size_t k_block_num) { + std::vector data = { + 0.1953, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + -0.1914, 0.2695, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + -0.2305, -0.1211, -0.1211, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + 0.0703, -0.0859, 0.2148, -0.1367, -65504.0, -65504.0, -65504.0, -65504.0, + -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + -0.1367, -0.4766, -0.0039, 0.0273, 0.2031, -65504.0, -65504.0, -65504.0, + -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + -0.4414, 0.0703, 0.3477, 0.4102, 0.2891, 0.4453, -65504.0, -65504.0, + -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + -0.2266, -0.1797, 0.1992, 0.1523, 0.0586, 0.5234, -0.2070, -65504.0, + -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + -0.3164, -0.0117, 0.0312, 0.2422, 0.3047, 0.1562, -0.1172, 0.0820, + -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + 0.4648, -0.0117, 0.1680, -0.3086, -0.2695, 0.3906, -0.1641, -0.1406, + -0.1211, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + 0.3086, -0.0156, 0.0430, -0.0938, -0.1484, 0.2773, -0.2812, 0.0039, + -0.1133, -0.2656, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + 0.5078, -0.0664, -0.2266, -0.6055, -0.2383, -0.1719, -0.0195, 0.2461, + 0.0859, -0.1680, 0.1875, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, + + -0.4922, 0.4258, 0.2578, 0.4219, 0.0820, 0.3711, 0.4688, -0.5859, + -0.1328, 0.4102, -0.2266, 0.2695, -65504.0, -65504.0, -65504.0, -65504.0, + + -0.5586, 0.5430, 0.1211, 0.3359, -0.0859, -0.3477, 0.2500, 0.0391, + -0.1797, 0.5430, -0.2109, 0.7695, 0.1484, -65504.0, -65504.0, -65504.0, + + 0.0859, -0.1406, 0.0430, -0.1406, -0.0938, -0.2539, -0.0781, -0.0273, + -0.0820, -0.2578, 0.0469, -0.0781, -0.2227, -0.2969, -65504.0, -65504.0, + + -0.2109, -0.2539, 0.3086, 0.7109, 0.2695, 0.5547, -0.0977, -0.5430, + -0.1953, -0.3242, -0.1289, -0.0156, -0.0547, -0.5391, 0.1133, -65504.0, + + 0.0742, 0.1758, 0.2344, -0.1523, -0.2109, -0.0508, 0.0859, -0.1953, + -0.1562, 0.1680, 0.3242, 0.0195, -0.4141, -0.3164, -0.1133, 0.2383 + }; + + size_t total = num_heads * q_block_num * k_block_num; + if (data.size() != total) { + std::cerr << "Error: expected total=" << total << " but data.size=" << data.size() << std::endl; + return; + } + + // qk_buf = std::shared_ptr(new float[total]); + std::copy(data.begin(), data.end(), qk_buf.get()); +} + +void print_causal_mask_buf(const std::shared_ptr& causal_mask_buf, + size_t num_heads, + size_t q_block_num, + size_t k_block_num) { + std::cout << "=== Debug: causal_mask_buf ===" << std::endl; + + for (size_t h = 0; h < num_heads; ++h) { + std::cout << "Head " << h << ":\n"; + for (size_t q = 0; q < q_block_num; ++q) { + for (size_t k = 0; k < k_block_num; ++k) { + size_t idx = h * q_block_num * k_block_num + q * k_block_num + k; + auto val = static_cast(causal_mask_buf[idx]); + std::cout << std::setw(6) << val << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + } + + std::cout << "=== End of causal_mask_buf ===" << std::endl; +} + +void print_q_buf(const std::shared_ptr& q_buf, + size_t num_heads, + size_t q_block_num, + size_t head_dim) { + std::cout << "=== Debug: q_buf ===" << std::endl; + + for (size_t h = 0; h < num_heads; ++h) { + std::cout << "Head " << h << ":\n"; + for (size_t q = 0; q < q_block_num; ++q) { + std::cout << "Q" << std::setw(2) << q << ": "; + for (size_t d = 0; d < head_dim; ++d) { + size_t idx = h * q_block_num * head_dim + q * head_dim + d; + auto val = static_cast(q_buf[idx]); + std::cout << std::fixed << std::setprecision(4) << std::setw(8) << val << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + } + + std::cout << "=== End of q_buf ===" << std::endl; +} + +void print_k_buf(const std::shared_ptr& k_buf, + size_t num_heads, + size_t q_block_num, + size_t head_dim) { + std::cout << "=== Debug: k_buf ===" << std::endl; + + for (size_t h = 0; h < num_heads; ++h) { + std::cout << "Head " << h << ":\n"; + for (size_t q = 0; q < q_block_num; ++q) { + std::cout << "Q" << std::setw(2) << q << ": "; + for (size_t d = 0; d < head_dim; ++d) { + size_t idx = h * q_block_num * head_dim + q * head_dim + d; + auto val = static_cast(k_buf[idx]); + std::cout << std::fixed << std::setprecision(4) << std::setw(8) << val << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + } + + std::cout << "=== End of q_buf ===" << std::endl; +} + +void print_query_data(const T* data, const std::vector& shape, const std::string& name = "query_data") { + if (!data) { + std::cout << name << " is nullptr\n"; + return; + } + + std::cout << "=== " << name << " ===\n"; + + if (shape.size() == 3) { // [num_heads, q_block_num, k_block_num] + size_t H = shape[0]; + size_t Q = shape[1]; + size_t K = shape[2]; + + for (size_t h = 0; h < H; ++h) { + std::cout << "Head " << h << ":\n"; + for (size_t q = 0; q < Q; ++q) { + for (size_t k = 0; k < K; ++k) { + size_t idx = h * Q * K + q * K + k; + std::cout << std::fixed << std::setprecision(4) + << static_cast(data[idx]) << " "; + } + std::cout << "\n"; + } + std::cout << "\n"; + } + } else if (shape.size() == 4) { // [B, H, Q, K] + size_t B = shape[0]; + size_t H = shape[1]; + size_t Q = shape[2]; + size_t K = shape[3]; + + for (size_t b = 0; b < B; ++b) { + std::cout << "Batch " << b << ":\n"; + for (size_t h = 0; h < H; ++h) { + std::cout << " Head " << h << ":\n"; + for (size_t q = 0; q < Q; ++q) { + std::cout << " "; + for (size_t k = 0; k < K; ++k) { + size_t idx = b * H * Q * K + h * Q * K + q * K + k; + std::cout << std::fixed << std::setprecision(4) + << static_cast(data[idx]) << " "; + } + std::cout << "\n"; + } + std::cout << "\n"; + } + } + } else { + std::cout << "Unsupported shape size=" << shape.size() << "\n"; + } + + std::cout << "=== End of " << name << " ===\n"; +} + +void set_q_buf(std::shared_ptr &q_buf) { + const size_t B = 1; + const size_t H = 1; + const size_t Q = 32; + const size_t dim = 4; + + // tmp_data 用 float 填写你的 chunked_query 数据 + float tmp_data[B*H*Q*dim] = { + -0.3750, 1.0000, -0.2500, 0.2500, -1.0000, -0.5000, -0.1250, 0.0000, + -0.6250, -0.2500, 0.7500, 0.7500, -0.2500, 0.3750, -0.3750, -0.3750, + -0.6250, -0.7500, 0.1250, 0.1250, 1.0000, 0.7500, -0.8750, 0.1250, + 0.3750, 0.8750, -0.1250, -0.2500, 1.0000, 0.7500, 0.2500, -0.2500, + 0.1250, 0.8750, -0.8750, -0.3750, 0.6250, -0.3750, -0.1250, -1.0000, + -0.3750, 0.7500, 0.0000, 0.8750, 0.7500, 0.2500, 0.6250, -0.6250, + 0.8750, -0.2500, -0.1250, 0.7500, 0.2500, 0.3750, -0.6250, -0.7500, + -0.7500, 0.0000, -0.2500, 0.6250, -1.0000, -0.5000, -0.6250, -1.0000, + 0.8750, 0.2500, 0.5000, -0.6250, -0.1250, 0.7500, -0.7500, -0.5000, + 1.0000, -0.3750, 0.6250, 0.3750, 0.2500, 0.5000, -0.5000, 0.7500, + 0.1250, 0.0000, 0.0000, -1.0000, 0.2500, 0.6250, -0.5000, 0.8750, + -0.7500, -0.6250, 0.8750, 0.7500, 1.0000, 0.7500, 0.7500, 0.1250, + -0.5000, -1.0000, 0.0000, 0.7500, -0.8750, -0.1250, 1.0000, -0.1250, + 0.7500, 0.7500, -0.7500, -0.1250, 0.1250, -0.1250, 0.6250, 0.1250, + 0.7500, 0.6250, 0.5000, 0.8750, 1.0000, -0.6250, 0.5000, -0.6250, + 0.3750, 0.6250, -0.2500, -0.3750, -0.3750, 0.3750, 0.5000, -0.6250 + }; + + for (size_t idx = 0; idx < B*H*Q*dim; ++idx) { + q_buf[idx] = ov::float16(tmp_data[idx]); + } +} - /** Applies XAttention to the provided query and key matrices, returning the subset of the most important blocks for - * each attention head, according to the configured block size and threshold, which are to be preserved in the - * subsequent sparse attention computation. - * @param query_data Pointer to the query input tensor data - * @param query_shape Shape of the query input tensor data. Expected shape is [num_heads, num_query_tokens, - * head_size], where `num_query_tokens` must be a multiple of both `block_size` and `stride`, padded with zeroes if - * necessary to do so in the real-world scenario. - * @param key_data Pointer to the key input tensor data - * @param key_shape Shape of the key input tensor data. Expected shape is [num_heads, num_key_tokens, head_size], - * where `num_key_tokens` must be a multiple of both `block_size` and `stride`, padded with zeroes if necessary to - * do so in the real-world scenario. - * @return A vector of size `num_heads` of sets, each set containing pairs of block indices (.first is the block - * index along the query dimension, .second - along the key). Each set is the head-specific subset of blocks that - * must be preserved in the sparse attention computation. Indices are given in units of XAttention-specific - * `block_size` (as configured), which may differ from the block size in the paged attention implementation. - */ XAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, const Shape& query_shape, const T* key_data, const Shape& key_shape) { - OPENVINO_ASSERT(query_shape.size() == 3); // [num_heads, query_token_len, head_dim] - OPENVINO_ASSERT(key_shape.size() == 3); // [num_heads, key_token_len, head_dim] - + OPENVINO_ASSERT(query_shape.size() == 3); + OPENVINO_ASSERT(key_shape.size() == 3); OPENVINO_ASSERT(key_shape[0] == query_shape[0]); OPENVINO_ASSERT(key_shape[2] == query_shape[2]); - OPENVINO_ASSERT(query_shape[1] % m_stride == 0); OPENVINO_ASSERT(key_shape[1] % m_stride == 0); - OPENVINO_ASSERT(query_shape[1] % m_block_size == 0); OPENVINO_ASSERT(key_shape[1] % m_block_size == 0); + // print_query_data(query_data, {1, 32, 4}); - Shape reshaped_query_shape = {query_shape[0], query_shape[1] / m_stride, query_shape[2] * m_stride}; - auto q_buf = allocate_buf(reshaped_query_shape); - diagonal_reshape(query_data, query_shape, q_buf.get(), reshaped_query_shape, /* is_antidiagonal = */ true); + size_t chunk_size = query_shape[1]; + size_t k_len = key_shape[1]; + size_t head_dim = query_shape[2]; + size_t num_heads = query_shape[0]; + size_t k_num_to_pad = ((k_len + chunk_size - 1) / chunk_size) * chunk_size - k_len; + Shape pad_key_shape = {num_heads, k_len + k_num_to_pad, head_dim}; + auto pad_key_buf = allocate_buf(pad_key_shape); - Shape reshaped_key_shape = {key_shape[0], key_shape[1] / m_stride, key_shape[2] * m_stride}; - auto k_buf = allocate_buf(reshaped_key_shape); - diagonal_reshape(key_data, key_shape, k_buf.get(), reshaped_key_shape, /* is_antidiagonal = */ false); + for (size_t h = 0; h < num_heads; h++) + for (size_t t = 0; t < k_len; t++) + for (size_t d = 0; d < head_dim; d++) { + size_t offset = h * (k_len + k_num_to_pad) * head_dim + t * head_dim + d; + size_t original_offset = h * k_len * head_dim + t * head_dim + d; + pad_key_buf.get()[offset] = key_data[original_offset]; + } + + size_t k_chunk_num = (k_len + k_num_to_pad) / chunk_size; + size_t offset_token_chunk_num = k_chunk_num - 1; + size_t reshaped_chunk_size = chunk_size / m_stride; + // size_t reshaped_block_size = m_block_size / m_stride; + size_t k_reshaped_num_to_pad = k_num_to_pad / m_stride; + size_t k_reshaped_seq_len = (k_len + k_num_to_pad) / m_stride; + + // size_t num_blocks_per_chunk = reshaped_chunk_size / reshaped_block_size; + + // size_t q_block_num = chunk_size / m_block_size; + + // size_t k_block_num = (k_len + k_num_to_pad) / m_block_size; - Shape transpose_matmul_scaled_shape = {key_shape[0], query_shape[1] / m_stride, key_shape[1] / m_stride}; + Shape reshaped_query_shape = {num_heads, query_shape[1] / m_stride, head_dim * m_stride}; + auto q_buf = allocate_buf(reshaped_query_shape); + diagonal_reshape_kdb1_no_batch(query_data, query_shape, q_buf.get(), reshaped_query_shape); + Shape reshaped_key_shape = {num_heads, pad_key_shape[1] / m_stride, head_dim * m_stride}; + auto k_buf = allocate_buf(reshaped_key_shape); + diagonal_reshape(pad_key_buf.get(), pad_key_shape, k_buf.get(), reshaped_key_shape, false); + Shape transpose_matmul_scaled_shape = {num_heads, query_shape[1] / m_stride, pad_key_shape[1] / m_stride}; + std::cout << "transpose_matmul_scaled_shape: \n"; + for (auto ii : transpose_matmul_scaled_shape) { + std::cout << ii << " "; + } + std::cout << std::endl; auto qk_buf = allocate_buf(transpose_matmul_scaled_shape); + + + // print_q_buf(q_buf, num_heads, query_shape[1] / m_stride, head_dim * m_stride); + // set_q_buf(q_buf); + // print_q_buf(q_buf, num_heads, query_shape[1] / m_stride, head_dim * m_stride); + // print_k_buf(k_buf, num_heads, pad_key_shape[1] / m_stride, head_dim * m_stride); transpose_matmul_scale(q_buf.get(), k_buf.get(), reshaped_query_shape, reshaped_key_shape, qk_buf.get(), transpose_matmul_scaled_shape); + // print_qk_buf(qk_buf, num_heads, 16, 16); + q_buf.reset(); k_buf.reset(); + Shape causal_mask_shape = {num_heads, reshaped_chunk_size, reshaped_chunk_size * k_chunk_num}; + auto causal_mask_buf = allocate_buf(causal_mask_shape); + std::fill(causal_mask_buf.get(), causal_mask_buf.get() + ov::shape_size(causal_mask_shape), T(0)); + if (k_reshaped_num_to_pad) { + for (size_t h = 0; h < num_heads; h++) + for (size_t q = 0; q < reshaped_chunk_size; q++) + for (size_t k = k_reshaped_seq_len - k_reshaped_num_to_pad; k < k_reshaped_seq_len; k++) { + size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + + q * (reshaped_chunk_size * k_chunk_num) + k; + + causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); + } + } + + size_t chunk_start = offset_token_chunk_num * reshaped_chunk_size; + + size_t chunk_end = chunk_start + reshaped_chunk_size; + + for (size_t h = 0; h < num_heads; h++) + for (size_t q = 0; q < reshaped_chunk_size; q++) + for (size_t k = q + 1; k < reshaped_chunk_size; k++) { + size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + + q * (reshaped_chunk_size * k_chunk_num) + chunk_start + k; + + causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); + } + + for (size_t h = 0; h < num_heads; h++) + for (size_t q = 0; q < reshaped_chunk_size; q++) + for (size_t k = chunk_end; k < reshaped_chunk_size * k_chunk_num; k++) { + size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + + q * (reshaped_chunk_size * k_chunk_num) + k; + + causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); + } + + // slice [: , : , 0 ::1 , : ] since kdb=1 + + size_t out_size = + transpose_matmul_scaled_shape[0] * transpose_matmul_scaled_shape[1] * transpose_matmul_scaled_shape[2]; + + + // print_causal_mask_buf(causal_mask_buf, num_heads, reshaped_chunk_size, reshaped_chunk_size * k_chunk_num); + + for (size_t i = 0; i < out_size; i++) { + qk_buf.get()[i] += causal_mask_buf.get()[i]; + } + + + + causal_mask_buf.reset(); Shape attention_scores_shape = transpose_matmul_scaled_shape; + auto attn_score_buf = allocate_buf(attention_scores_shape); + + // print_qk_buf(qk_buf, num_heads, 16, 16); + // assign_qk_buf(qk_buf, num_heads, 16, 16); + // print_qk_buf(qk_buf, num_heads, 16, 16); + + softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape); + qk_buf.reset(); + // print_attn_score_buf_with_shape(attn_score_buf, + // transpose_matmul_scaled_shape[0], + // transpose_matmul_scaled_shape[1], + // transpose_matmul_scaled_shape[2]); + + + + size_t antidiagonals_per_xattention_block = m_block_size / m_stride; Shape block_sum_shape = {attention_scores_shape[0], attention_scores_shape[1] / antidiagonals_per_xattention_block, attention_scores_shape[2] / antidiagonals_per_xattention_block}; + auto block_sum_buf = allocate_buf(block_sum_shape); block_sum_attention_scores(attn_score_buf.get(), attention_scores_shape, block_sum_buf.get(), block_sum_shape); attn_score_buf.reset(); - auto selected_block_indices = get_block_indices_to_keep(block_sum_buf.get(), block_sum_shape); block_sum_buf.reset(); + // The Python has the tril on the last q_block_num + + // So, to match, the simple_masks [: , : , -q_block_num : , -q_block_num : ] = where (tril, simple_masks, False) + + // But since the return is the set, we can do in the retained, erase the upper + + // Yes, already has. + return selected_block_indices; } - /** - * @param shape Shape of a tensor - * @return A shared_ptr owning a buffer that can be used to store tensor data for the given shape. - * */ std::shared_ptr allocate_buf(const Shape& shape) { return std::shared_ptr(new T[ov::shape_size(shape)]); } - /** - * @param token_length An integer value - * @return The closest multiple of `block_size` to `token_length`, rounding up. - * */ size_t pad_to_block(size_t token_length) { return (token_length + m_block_size - 1) / m_block_size * m_block_size; } double m_threshold; + size_t m_block_size; + size_t m_stride; }; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index b246637d3e7c9e..a640dba5b04545 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -104,28 +104,60 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { // stream.finish(); // std::cout << "finish xattn_estimate_gemmqk!\n"; res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; -#if DUMP_XATTN_BLOCK_MASK +// #if DUMP_XATTN_BLOCK_MASK + // { + // cldnn::stream& stream = instance.get_network().get_stream(); + // stream.finish(); + // static uint32_t pa_id = 0; + // std::cout << "finish xattn_estimate_find_block!\n"; + // auto output_mem = instance.get_intermediates_memories()[4]; + // mem_lock lock(output_mem, stream); + // auto& layout = output_mem->get_layout(); + // std::string data_type = ov::element::Type(layout.data_type).get_type_name(); + // std::string format = layout.format.to_string(); + // std::string tensor; + // auto dims = layout.get_dims(); + // for (size_t r = 0 ; r < layout.get_rank() ; r++) { + // tensor += ("_" + to_string(dims[r])); + // } + // // std::string filename = "PA" + std::to_string(pa_id) + "__" + data_type + "_" + tensor + "__" + format + ".bin"; + // std::string filename = "PA" + std::to_string(pa_id) + ".bin"; + // ov::util::save_binary(filename, lock.data(), output_mem->size()); + // pa_id++; + // } { cldnn::stream& stream = instance.get_network().get_stream(); stream.finish(); static uint32_t pa_id = 0; std::cout << "finish xattn_estimate_find_block!\n"; - auto output_mem = instance.get_intermediates_memories()[4]; - mem_lock lock(output_mem, stream); - auto& layout = output_mem->get_layout(); - std::string data_type = ov::element::Type(layout.data_type).get_type_name(); - std::string format = layout.format.to_string(); - std::string tensor; - auto dims = layout.get_dims(); - for (size_t r = 0 ; r < layout.get_rank() ; r++) { - tensor += ("_" + to_string(dims[r])); + for (int index = 0; index < 5; index++) { + auto output_mem = instance.get_intermediates_memories()[4]; + mem_lock lock(output_mem, stream); + auto& layout = output_mem->get_layout(); + auto dims = layout.get_dims(); + size_t total_size = output_mem->size(); + + std::cout << "PA" << pa_id << " layout: rank=" << layout.get_rank() + << ", dims=["; + for (size_t r = 0; r < dims.size(); r++) { + std::cout << dims[r]; + if (r != dims.size() - 1) std::cout << ","; + } + std::cout << "], total_size=" << total_size << "\n"; + + size_t max_print = total_size; //std::min(100, total_size); + std::cout << "Data: "; + for (size_t i = 0; i < max_print; i++) { + if (i % 32 == 0) std::cout << std::endl; + std::cout << static_cast(lock.data()[i]) << " "; + } + if (total_size > max_print) std::cout << "..."; + std::cout << "\n"; } - // std::string filename = "PA" + std::to_string(pa_id) + "__" + data_type + "_" + tensor + "__" + format + ".bin"; - std::string filename = "PA" + std::to_string(pa_id) + ".bin"; - ov::util::save_binary(filename, lock.data(), output_mem->size()); + pa_id++; } -#endif +// #endif res_event = {execute_stage(res_event, instance, xattn_estimate_post_proc)}; } res_event = {execute_stage(res_event, instance, pa_multi_token)}; diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index 067d7817a4a13e..9fe17c5274783b 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -16,6 +16,9 @@ #include "paged_attention_gpu_test.hpp" #include "random_generator.hpp" #include "test_utils.h" +#include +#include + using namespace cldnn; using namespace ov::intel_gpu; @@ -90,18 +93,18 @@ struct xAttentionReference { } private: - // void print_tensor(const std::vector& data, size_t heads, size_t rows, size_t cols, const std::string& name) { - // std::cout << name << " (" << heads << "x" << rows << "x" << cols << "):\n"; - // for (size_t h = 0; h < heads; h++) { - // std::cout << " Head " << h << ":\n"; - // for (size_t i = 0; i < rows; i++) { - // for (size_t j = 0; j < cols; j++) { - // std::cout << static_cast(data[h * rows * cols + i * cols + j]) << " "; - // } - // std::cout << "\n"; - // } - // } - // } + void print_tensor(const std::vector& data, size_t heads, size_t rows, size_t cols, const std::string& name) { + std::cout << name << " (" << heads << "x" << rows << "x" << cols << "):\n"; + for (size_t h = 0; h < heads; h++) { + std::cout << " Head " << h << ":\n"; + for (size_t i = 0; i < rows; i++) { + for (size_t j = 0; j < cols; j++) { + std::cout << static_cast(data[h * rows * cols + i * cols + j]) << "\n"; + } + std::cout << "\n"; + } + } + } std::vector softmax_1(const std::vector& logits) { std::vector out(logits.size()); @@ -239,6 +242,20 @@ struct xAttentionReference { return output; } + +// 保存为二进制 .bin 文件 +void save_tensor_to_bin(const std::string& filename, const std::vector& data) { + std::ofstream file(filename, std::ios::out | std::ios::binary); + if (!file) { + std::cerr << "Failed to open " << filename << " for writing" << std::endl; + return; + } + file.write(reinterpret_cast(data.data()), data.size() * sizeof(ov::float16)); + file.close(); + std::cout << "[Info] Saved " << filename << " (" << data.size() << " elements)" << std::endl; +} + + std::pair, std::vector> run_reference(const std::vector& query_data, const std::vector& key_data, const std::vector& value_data, @@ -250,8 +267,8 @@ struct xAttentionReference { int window_size, int sliding_window_size, float scale, - double threshold = 0.8, - size_t block_size = 256, + double threshold = 0.9, + size_t block_size = 128, size_t stride = 16) { // --- 1. allocate memory --- auto query_shape_bfyx = ov::PartialShape{1, num_queries, num_heads, k_head_size}; @@ -274,30 +291,6 @@ struct xAttentionReference { set_values(key_mem, key_data); set_values(value_mem, value_data); - // std::cout << "=== query_data (bfyx layout) ===" << std::endl; - // for (int q = 0; q < num_queries; q++) { - // for (int h = 0; h < num_heads; h++) { - // std::cout << "q=" << q << ", h=" << h << ": ["; - // for (int d = 0; d < k_head_size; d++) { - // auto val = query_data[q * num_heads * k_head_size + h * k_head_size + d]; - // std::cout << static_cast(val) << (d + 1 < k_head_size ? ", " : ""); - // } - // std::cout << "]" << std::endl; - // } - // } - - // std::cout << "=== key_data (bfyx layout) ===" << std::endl; - // for (int k = 0; k < num_keys; k++) { - // for (int h = 0; h < num_heads; h++) { - // std::cout << "k=" << k << ", h=" << h << ": ["; - // for (int d = 0; d < k_head_size; d++) { - // auto val = key_data[k * num_heads * k_head_size + h * k_head_size + d]; - // std::cout << static_cast(val) << (d + 1 < k_head_size ? ", " : ""); - // } - // std::cout << "]" << std::endl; - // } - // } - std::vector query_data_3d(num_heads * num_queries * k_head_size); std::vector key_data_3d(num_heads * num_keys * k_head_size); @@ -321,19 +314,102 @@ struct xAttentionReference { ov::Shape key_shape_3d = {static_cast(num_heads), static_cast(num_keys), static_cast(k_head_size)}; ov::reference::XAttentionRetainedBlockIndicesForAllHeads retained_blocks; - { - ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); - retained_blocks = selector.select_blocks(query_data_3d.data(), query_shape_3d, key_data_3d.data(), key_shape_3d); - - // std::cout << "=== C++ 选中 blocks ===" << std::endl; - // for (size_t h = 0; h < retained_blocks.size(); ++h) { - // std::cout << "Head " << h << " selected blocks: "; - // for (const auto& idx_pair : retained_blocks[h]) { - // std::cout << "(" << idx_pair.first << "," << idx_pair.second << ") "; - // } - // std::cout << std::endl; - // } + // { + // ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); + // retained_blocks = selector.select_blocks(query_data_3d.data(), query_shape_3d, key_data_3d.data(), key_shape_3d); + + // std::cout << "=== C++ 选中 blocks ===" << std::endl; + // for (size_t h = 0; h < retained_blocks.size(); ++h) { + // std::cout << "Head " << h << " selected blocks: "; + // for (const auto& idx_pair : retained_blocks[h]) { + // std::cout << "(" << idx_pair.first << "," << idx_pair.second << ") "; + // } + // std::cout << std::endl; + // } + // } + + + if (num_queries < static_cast(block_size)) { + // Case 1: too few queries — skip block selection + std::cout << "[Info] num_queries < block_size, skip block selection." << std::endl; + } else { + // Case 2: handle non-divisible length via padding + size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; + size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; + + if (padded_q != static_cast(num_queries) || padded_k != static_cast(num_keys)) { + std::cout << "[Info] Padding Q/K length for block alignment: " + << "Q " << num_queries << "→" << padded_q + << ", K " << num_keys << "→" << padded_k << std::endl; + } + + // Build padded buffers for selection + std::vector query_padded(num_heads * padded_q * k_head_size, ov::float16(0)); + std::vector key_padded(num_heads * padded_k * k_head_size, ov::float16(0)); + + for (int h = 0; h < num_heads; ++h) { + std::copy_n(&query_data_3d[h * num_queries * k_head_size], + num_queries * k_head_size, + &query_padded[h * padded_q * k_head_size]); + std::copy_n(&key_data_3d[h * num_keys * k_head_size], + num_keys * k_head_size, + &key_padded[h * padded_k * k_head_size]); + } + + ov::Shape query_shape_padded = {static_cast(num_heads), padded_q, static_cast(k_head_size)}; + ov::Shape key_shape_padded = {static_cast(num_heads), padded_k, static_cast(k_head_size)}; + + + // === Save padded Q/K for Python comparison === + save_tensor_to_bin("q_padded.bin", query_padded); + save_tensor_to_bin("k_padded.bin", key_padded); + + std::ofstream meta("meta.txt"); + meta << "num_heads=" << num_heads << "\n"; + meta << "padded_q=" << padded_q << "\n"; + meta << "padded_k=" << padded_k << "\n"; + meta << "k_head_size=" << k_head_size << "\n"; + meta << "block_size=" << block_size << "\n"; + meta << "stride=" << stride << "\n"; + meta << "threshold=" << threshold << "\n"; + meta.close(); + std::cout << "[Info] Saved meta.txt with shape info" << std::endl; + + + std::vector query_padded_f32(query_padded.size()); + std::vector key_padded_f32(key_padded.size()); + for (size_t i = 0; i < query_padded.size(); ++i) + query_padded_f32[i] = static_cast(query_padded[i]); + for (size_t i = 0; i < key_padded.size(); ++i) + key_padded_f32[i] = static_cast(key_padded[i]); + + ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); + retained_blocks = selector.select_blocks(query_padded_f32.data(), query_shape_padded, + key_padded_f32.data(), key_shape_padded); + + std::cout << "=== Selected blocks after padding ===" << std::endl; + for (size_t h = 0; h < retained_blocks.size(); ++h) { + std::cout << "Head " << h << " selected blocks: "; + for (const auto& idx_pair : retained_blocks[h]) { + std::cout << "(" << idx_pair.first << "," << idx_pair.second << ") "; + } + std::cout << std::endl; } + } + + // retained_blocks = { + // { // Head 0 + // {0,0}, {1,0}, {1,1}, {2,0}, {2,1}, {2,2}, {3,0}, {3,1}, {3,2}, {3,3}, {4,0}, {4,1}, {4,2}, {4,3}, {4,4}, {5,0}, {5,1}, {5,2}, {5,3}, {5,4}, {5,5}, {6,0}, {6,1}, {6,2}, {6,3}, {6,4}, {6,5}, {6,6}, {7,0}, {7,1}, {7,2}, {7,3}, {7,4}, {7,5}, {7,6}, {7,7}, {8,0}, {8,1}, {8,2}, {8,3}, {8,4}, {8,5}, {8,6}, {8,7}, {8,8}, {9,0}, {9,1}, {9,2}, {9,3}, {9,4}, {9,5}, {9,6}, {9,7}, {9,8}, {9,9}, {10,0}, {10,1}, {10,2}, {10,3}, {10,5}, {10,6}, {10,7}, {10,8}, {10,9}, {10,10}, {11,0}, {11,1}, {11,2}, {11,3}, {11,4}, {11,5}, {11,7}, {11,8}, {11,9}, {11,10}, {11,11}, {12,0}, {12,1}, {12,2}, {12,3}, {12,4}, {12,5}, {12,6}, {12,7}, {12,8}, {12,9}, {12,10}, {12,12}, {13,0}, {13,1}, {13,2}, {13,3}, {13,4}, {13,5}, {13,6}, {13,7}, {13,8}, {13,9}, {13,11}, {13,12}, {13,13}, {14,0}, {14,1}, {14,2}, {14,3}, {14,4}, {14,5}, {14,6}, {14,8}, {14,9}, {14,10}, {14,11}, {14,12}, {14,13}, {14,14}, {15,0}, {15,1}, {15,2}, {15,3}, {15,4}, {15,5}, {15,6}, {15,8}, {15,9}, {15,10}, {15,11}, {15,12}, {15,13}, {15,14}, {15,15}, {16,0}, {16,1}, {16,2}, {16,3}, {16,4}, {16,5}, {16,6}, {16,7}, {16,8}, {16,9}, {16,10}, {16,11}, {16,12}, {16,13}, {16,14}, {16,16}, {17,0}, {17,1}, {17,2}, {17,3}, {17,4}, {17,5}, {17,6}, {17,7}, {17,8}, {17,9}, {17,10}, {17,11}, {17,12}, {17,13}, {17,14}, {17,15}, {17,17}, {18,0}, {18,1}, {18,2}, {18,3}, {18,4}, {18,5}, {18,6}, {18,7}, {18,9}, {18,10}, {18,11}, {18,12}, {18,13}, {18,14}, {18,15}, {18,16}, {18,17}, {18,18}, {19,0}, {19,1}, {19,3}, {19,4}, {19,5}, {19,6}, {19,7}, {19,8}, {19,9}, {19,10}, {19,11}, {19,12}, {19,13}, {19,14}, {19,15}, {19,16}, {19,17}, {19,18}, {19,19}, {20,0}, {20,1}, {20,2}, {20,3}, {20,4}, {20,5}, {20,6}, {20,7}, {20,9}, {20,10}, {20,12}, {20,13}, {20,14}, {20,15}, {20,16}, {20,17}, {20,18}, {20,19}, {20,20}, {21,0}, {21,1}, {21,2}, {21,3}, {21,4}, {21,5}, {21,6}, {21,7}, {21,8}, {21,9}, {21,10}, {21,12}, {21,13}, {21,14}, {21,16}, {21,17}, {21,18}, {21,19}, {21,20}, {21,21}, {22,0}, {22,1}, {22,2}, {22,3}, {22,4}, {22,5}, {22,7}, {22,8}, {22,10}, {22,11}, {22,12}, {22,13}, {22,14}, {22,15}, {22,16}, {22,17}, {22,18}, {22,19}, {22,20}, {22,21}, {22,22}, {23,0}, {23,1}, {23,2}, {23,3}, {23,4}, {23,5}, {23,6}, {23,7}, {23,9}, {23,10}, {23,11}, {23,13}, {23,14}, {23,15}, {23,16}, {23,17}, {23,18}, {23,19}, {23,20}, {23,21}, {23,22}, {23,23}, {24,0}, {24,2}, {24,3}, {24,4}, {24,5}, {24,6}, {24,7}, {24,9}, {24,10}, {24,11}, {24,12}, {24,13}, {24,14}, {24,15}, {24,16}, {24,17}, {24,18}, {24,19}, {24,20}, {24,21}, {24,22}, {24,23}, {24,24}, {25,0}, {25,1}, {25,2}, {25,3}, {25,4}, {25,5}, {25,6}, {25,7}, {25,8}, {25,9}, {25,10}, {25,11}, {25,12}, {25,13}, {25,14}, {25,15}, {25,16}, {25,17}, {25,19}, {25,20}, {25,22}, {25,23}, {25,24}, {25,25}, {26,0}, {26,1}, {26,2}, {26,3}, {26,4}, {26,5}, {26,6}, {26,7}, {26,8}, {26,9}, {26,10}, {26,12}, {26,13}, {26,14}, {26,15}, {26,16}, {26,17}, {26,18}, {26,19}, {26,20}, {26,21}, {26,23}, {26,24}, {26,25}, {26,26}, {27,0}, {27,1}, {27,3}, {27,4}, {27,5}, {27,6}, {27,7}, {27,8}, {27,9}, {27,10}, {27,12}, {27,13}, {27,14}, {27,15}, {27,16}, {27,17}, {27,18}, {27,19}, {27,20}, {27,21}, {27,22}, {27,23}, {27,24}, {27,25}, {27,26}, {27,27}, {28,0}, {28,1}, {28,2}, {28,3}, {28,4}, {28,5}, {28,6}, {28,7}, {28,8}, {28,9}, {28,11}, {28,12}, {28,13}, {28,14}, {28,15}, {28,16}, {28,17}, {28,18}, {28,19}, {28,20}, {28,21}, {28,22}, {28,23}, {28,24}, {28,25}, {28,26}, {28,28}, {29,0}, {29,1}, {29,2}, {29,3}, {29,4}, {29,5}, {29,6}, {29,7}, {29,8}, {29,9}, {29,11}, {29,12}, {29,13}, {29,14}, {29,15}, {29,17}, {29,18}, {29,19}, {29,20}, {29,21}, {29,22}, {29,23}, {29,24}, {29,25}, {29,26}, {29,27}, {29,29}, {30,0}, {30,1}, {30,2}, {30,3}, {30,4}, {30,5}, {30,6}, {30,7}, {30,8}, {30,9}, {30,10}, {30,13}, {30,14}, {30,15}, {30,16}, {30,17}, {30,18}, {30,19}, {30,20}, {30,21}, {30,23}, {30,24}, {30,25}, {30,26}, {30,27}, {30,28}, {30,29}, {30,30}, {31,0}, {31,1}, {31,2}, {31,4}, {31,5}, {31,6}, {31,8}, {31,9}, {31,10}, {31,11}, {31,12}, {31,13}, {31,14}, {31,15}, {31,16}, {31,17}, {31,18}, {31,19}, {31,20}, {31,21}, {31,22}, {31,23}, {31,25}, {31,26}, {31,27}, {31,28}, {31,29}, {31,30}, {31,31} + // }, + // { // Head 1 + // {0,0}, {1,0}, {1,1}, {2,0}, {2,1}, {2,2}, {3,0}, {3,1}, {3,2}, {3,3}, {4,0}, {4,1}, {4,2}, {4,3}, {4,4}, {5,0}, {5,1}, {5,2}, {5,3}, {5,4}, {5,5}, {6,0}, {6,1}, {6,2}, {6,3}, {6,4}, {6,5}, {6,6}, {7,0}, {7,1}, {7,2}, {7,3}, {7,4}, {7,5}, {7,6}, {7,7}, {8,0}, {8,1}, {8,2}, {8,3}, {8,4}, {8,5}, {8,6}, {8,7}, {8,8}, {9,0}, {9,1}, {9,2}, {9,3}, {9,4}, {9,5}, {9,6}, {9,7}, {9,8}, {9,9}, {10,0}, {10,1}, {10,2}, {10,3}, {10,4}, {10,6}, {10,7}, {10,8}, {10,9}, {10,10}, {11,0}, {11,1}, {11,2}, {11,3}, {11,4}, {11,5}, {11,6}, {11,8}, {11,9}, {11,10}, {11,11}, {12,0}, {12,1}, {12,2}, {12,3}, {12,5}, {12,6}, {12,7}, {12,8}, {12,9}, {12,10}, {12,11}, {12,12}, {13,0}, {13,1}, {13,2}, {13,3}, {13,4}, {13,5}, {13,6}, {13,8}, {13,9}, {13,10}, {13,11}, {13,12}, {13,13}, {14,0}, {14,1}, {14,2}, {14,3}, {14,4}, {14,5}, {14,6}, {14,8}, {14,9}, {14,10}, {14,11}, {14,12}, {14,13}, {14,14}, {15,0}, {15,1}, {15,2}, {15,3}, {15,4}, {15,5}, {15,6}, {15,7}, {15,8}, {15,9}, {15,10}, {15,11}, {15,12}, {15,14}, {15,15}, {16,0}, {16,1}, {16,2}, {16,3}, {16,4}, {16,5}, {16,7}, {16,8}, {16,9}, {16,10}, {16,11}, {16,12}, {16,13}, {16,14}, {16,15}, {16,16}, {17,0}, {17,2}, {17,3}, {17,4}, {17,5}, {17,6}, {17,7}, {17,8}, {17,9}, {17,10}, {17,11}, {17,12}, {17,13}, {17,14}, {17,15}, {17,16}, {17,17}, {18,0}, {18,1}, {18,2}, {18,3}, {18,4}, {18,5}, {18,6}, {18,7}, {18,8}, {18,9}, {18,10}, {18,11}, {18,12}, {18,13}, {18,14}, {18,15}, {18,17}, {18,18}, {19,0}, {19,1}, {19,2}, {19,3}, {19,4}, {19,5}, {19,6}, {19,7}, {19,8}, {19,9}, {19,10}, {19,11}, {19,12}, {19,13}, {19,15}, {19,16}, {19,17}, {19,18}, {19,19}, {20,0}, {20,1}, {20,2}, {20,3}, {20,4}, {20,5}, {20,6}, {20,7}, {20,8}, {20,10}, {20,11}, {20,12}, {20,13}, {20,14}, {20,15}, {20,16}, {20,18}, {20,19}, {20,20}, {21,0}, {21,1}, {21,2}, {21,4}, {21,5}, {21,6}, {21,7}, {21,9}, {21,10}, {21,11}, {21,12}, {21,13}, {21,14}, {21,15}, {21,16}, {21,17}, {21,18}, {21,19}, {21,20}, {21,21}, {22,0}, {22,1}, {22,2}, {22,3}, {22,4}, {22,5}, {22,7}, {22,8}, {22,9}, {22,10}, {22,11}, {22,12}, {22,13}, {22,14}, {22,15}, {22,16}, {22,17}, {22,18}, {22,20}, {22,21}, {22,22}, {23,0}, {23,1}, {23,2}, {23,3}, {23,5}, {23,6}, {23,7}, {23,8}, {23,9}, {23,10}, {23,11}, {23,12}, {23,13}, {23,14}, {23,15}, {23,16}, {23,18}, {23,19}, {23,20}, {23,21}, {23,22}, {23,23}, {24,0}, {24,1}, {24,2}, {24,3}, {24,4}, {24,5}, {24,6}, {24,7}, {24,9}, {24,10}, {24,11}, {24,13}, {24,14}, {24,15}, {24,16}, {24,17}, {24,18}, {24,19}, {24,20}, {24,21}, {24,22}, {24,23}, {24,24}, {25,0}, {25,1}, {25,2}, {25,3}, {25,4}, {25,5}, {25,6}, {25,7}, {25,8}, {25,10}, {25,11}, {25,12}, {25,13}, {25,14}, {25,15}, {25,16}, {25,17}, {25,18}, {25,19}, {25,20}, {25,21}, {25,22}, {25,24}, {25,25}, {26,0}, {26,1}, {26,2}, {26,3}, {26,4}, {26,5}, {26,6}, {26,7}, {26,8}, {26,9}, {26,10}, {26,11}, {26,12}, {26,13}, {26,15}, {26,16}, {26,17}, {26,18}, {26,19}, {26,20}, {26,21}, {26,23}, {26,24}, {26,25}, {26,26}, {27,0}, {27,1}, {27,2}, {27,3}, {27,4}, {27,5}, {27,6}, {27,7}, {27,8}, {27,9}, {27,10}, {27,11}, {27,12}, {27,13}, {27,14}, {27,16}, {27,17}, {27,18}, {27,19}, {27,20}, {27,22}, {27,23}, {27,24}, {27,25}, {27,26}, {27,27}, {28,0}, {28,1}, {28,2}, {28,3}, {28,4}, {28,5}, {28,6}, {28,7}, {28,8}, {28,9}, {28,11}, {28,12}, {28,13}, {28,14}, {28,15}, {28,16}, {28,17}, {28,18}, {28,19}, {28,20}, {28,21}, {28,22}, {28,24}, {28,25}, {28,26}, {28,27}, {28,28}, {29,0}, {29,1}, {29,2}, {29,3}, {29,4}, {29,5}, {29,7}, {29,8}, {29,9}, {29,11}, {29,12}, {29,13}, {29,14}, {29,15}, {29,16}, {29,17}, {29,18}, {29,19}, {29,20}, {29,21}, {29,22}, {29,23}, {29,24}, {29,25}, {29,27}, {29,28}, {29,29}, {30,0}, {30,1}, {30,2}, {30,3}, {30,4}, {30,6}, {30,7}, {30,8}, {30,9}, {30,10}, {30,11}, {30,12}, {30,14}, {30,15}, {30,16}, {30,17}, {30,18}, {30,19}, {30,20}, {30,21}, {30,22}, {30,24}, {30,25}, {30,26}, {30,27}, {30,28}, {30,29}, {30,30}, {31,0}, {31,1}, {31,2}, {31,3}, {31,4}, {31,5}, {31,6}, {31,7}, {31,9}, {31,10}, {31,11}, {31,12}, {31,14}, {31,15}, {31,16}, {31,17}, {31,18}, {31,19}, {31,20}, {31,21}, {31,23}, {31,24}, {31,25}, {31,26}, {31,27}, {31,28}, {31,29}, {31,30}, {31,31} + // } + // }; + +// retained_blocks = {{ // Head 0 +// {0,0}, {1,0}, {1,1}, {2,0}, {2,1}, {2,2}, {3,0}, {3,1}, {3,2}, {3,3}, {4,0}, {4,1}, {4,2}, {4,3}, {4,4}, {5,0}, {5,1}, {5,2}, {5,3}, {5,4}, {5,5}, {6,0}, {6,1}, {6,2}, {6,3}, {6,4}, {6,5}, {6,6}, {7,0}, {7,1}, {7,2}, {7,3}, {7,4}, {7,5}, {7,6}, {7,7}, {8,0}, {8,1}, {8,2}, {8,3}, {8,4}, {8,5}, {8,6}, {8,7}, {8,8}, {9,0}, {9,1}, {9,2}, {9,3}, {9,4}, {9,5}, {9,6}, {9,7}, {9,8}, {9,9}, {10,0}, {10,1}, {10,2}, {10,3}, {10,4}, {10,5}, {10,6}, {10,7}, {10,9}, {10,10}, {11,0}, {11,1}, {11,2}, {11,3}, {11,4}, {11,5}, {11,6}, {11,7}, {11,9}, {11,10}, {11,11}, {12,0}, {12,1}, {12,2}, {12,3}, {12,4}, {12,6}, {12,7}, {12,8}, {12,9}, {12,10}, {12,11}, {12,12}, {13,0}, {13,2}, {13,3}, {13,4}, {13,5}, {13,6}, {13,7}, {13,8}, {13,9}, {13,10}, {13,11}, {13,12}, {13,13}, {14,0}, {14,1}, {14,2}, {14,3}, {14,4}, {14,5}, {14,7}, {14,8}, {14,9}, {14,10}, {14,11}, {14,12}, {14,13}, {14,14}, {15,0}, {15,1}, {15,2}, {15,3}, {15,5}, {15,6}, {15,7}, {15,8}, {15,9}, {15,10}, {15,11}, {15,12}, {15,13}, {15,14}, {15,15}, {16,0}, {16,1}, {16,2}, {16,3}, {16,4}, {16,5}, {16,6}, {16,7}, {16,8}, {16,9}, {16,10}, {16,12}, {16,13}, {16,14}, {16,15}, {16,16}, {17,0}, {17,1}, {17,2}, {17,4}, {17,5}, {17,6}, {17,7}, {17,8}, {17,9}, {17,10}, {17,11}, {17,12}, {17,13}, {17,14}, {17,15}, {17,16}, {17,17}, {18,0}, {18,1}, {18,2}, {18,3}, {18,4}, {18,5}, {18,7}, {18,8}, {18,9}, {18,10}, {18,11}, {18,12}, {18,13}, {18,14}, {18,15}, {18,16}, {18,17}, {18,18}, {19,0}, {19,1}, {19,2}, {19,4}, {19,5}, {19,6}, {19,7}, {19,8}, {19,9}, {19,10}, {19,11}, {19,12}, {19,13}, {19,14}, {19,15}, {19,16}, {19,17}, {19,18}, {19,19}, {20,0}, {20,1}, {20,2}, {20,3}, {20,4}, {20,5}, {20,6}, {20,7}, {20,9}, {20,10}, {20,11}, {20,13}, {20,14}, {20,15}, {20,16}, {20,17}, {20,18}, {20,19}, {20,20}, {21,0}, {21,1}, {21,2}, {21,3}, {21,5}, {21,6}, {21,7}, {21,8}, {21,9}, {21,10}, {21,11}, {21,12}, {21,14}, {21,15}, {21,16}, {21,17}, {21,18}, {21,19}, {21,20}, {21,21}, {22,0}, {22,1}, {22,2}, {22,3}, {22,4}, {22,5}, {22,7}, {22,8}, {22,9}, {22,10}, {22,11}, {22,12}, {22,14}, {22,15}, {22,16}, {22,17}, {22,18}, {22,19}, {22,20}, {22,21}, {22,22}, {23,0}, {23,1}, {23,2}, {23,3}, {23,4}, {23,5}, {23,7}, {23,8}, {23,9}, {23,10}, {23,11}, {23,12}, {23,13}, {23,14}, {23,16}, {23,17}, {23,18}, {23,19}, {23,20}, {23,21}, {23,22}, {23,23}, {24,0}, {24,1}, {24,2}, {24,3}, {24,5}, {24,6}, {24,7}, {24,8}, {24,10}, {24,11}, {24,12}, {24,13}, {24,14}, {24,15}, {24,16}, {24,17}, {24,18}, {24,19}, {24,20}, {24,21}, {24,22}, {24,23}, {24,24}, {25,0}, {25,1}, {25,2}, {25,3}, {25,6}, {25,7}, {25,8}, {25,9}, {25,10}, {25,11}, {25,12}, {25,13}, {25,14}, {25,15}, {25,16}, {25,17}, {25,18}, {25,19}, {25,20}, {25,21}, {25,22}, {25,23}, {25,24}, {25,25}, {26,0}, {26,1}, {26,2}, {26,3}, {26,4}, {26,5}, {26,6}, {26,7}, {26,8}, {26,9}, {26,10}, {26,11}, {26,12}, {26,13}, {26,14}, {26,15}, {26,16}, {26,17}, {26,18}, {26,20}, {26,21}, {26,22}, {26,23}, {26,24}, {26,26}, {27,0}, {27,1}, {27,2}, {27,3}, {27,4}, {27,7}, {27,8}, {27,9}, {27,10}, {27,11}, {27,12}, {27,13}, {27,14}, {27,15}, {27,16}, {27,17}, {27,18}, {27,19}, {27,20}, {27,21}, {27,22}, {27,23}, {27,24}, {27,25}, {27,26}, {27,27}, {28,0}, {28,1}, {28,2}, {28,3}, {28,4}, {28,5}, {28,6}, {28,7}, {28,9}, {28,10}, {28,11}, {28,12}, {28,13}, {28,14}, {28,15}, {28,16}, {28,17}, {28,18}, {28,19}, {28,20}, {28,21}, {28,22}, {28,23}, {28,25}, {28,26}, {28,27}, {28,28}, {29,0}, {29,1}, {29,2}, {29,3}, {29,4}, {29,5}, {29,6}, {29,7}, {29,8}, {29,9}, {29,10}, {29,11}, {29,12}, {29,13}, {29,14}, {29,15}, {29,16}, {29,17}, {29,18}, {29,19}, {29,22}, {29,23}, {29,24}, {29,25}, {29,26}, {29,27}, {29,28}, {29,29}, {30,0}, {30,1}, {30,2}, {30,3}, {30,4}, {30,5}, {30,6}, {30,7}, {30,8}, {30,9}, {30,10}, {30,11}, {30,12}, {30,13}, {30,14}, {30,15}, {30,16}, {30,17}, {30,18}, {30,19}, {30,20}, {30,22}, {30,23}, {30,24}, {30,25}, {30,27}, {30,28}, {30,30}, {31,0}, {31,1}, {31,2}, {31,3}, {31,4}, {31,5}, {31,6}, {31,8}, {31,9}, {31,10}, {31,11}, {31,12}, {31,13}, {31,14}, {31,15}, {31,16}, {31,17}, {31,18}, {31,19}, {31,20}, {31,22}, {31,23}, {31,24}, {31,25}, {31,27}, {31,28}, {31,29}, {31,30}, {31,31} +// }}; + // auto output = compute_sparse_causal_attention(query_data, // key_data, @@ -349,6 +425,7 @@ struct xAttentionReference { // print_tensor(output, num_heads, num_queries, k_head_size, "Output"); auto mask_mem = get_mask_mem_combined_multi_head(num_queries, num_keys, num_heads, sliding_window_size, retained_blocks, block_size); + // auto mask_mem = get_mask_mem(num_queries, num_keys, num_heads, sliding_window_size); topology topology; topology.add(input_layout("query", query_layout), @@ -420,6 +497,8 @@ struct xAttentionReference { int block_size) { // mask layout: [1, num_heads, num_queries, num_keys] auto mask_shape = ov::PartialShape{1, num_heads, num_queries, num_keys}; + std::cout << "**********************************************************************\n"; + std::cout << num_heads << " " << num_queries << " " << num_keys << std::endl; auto mask_layout = layout{mask_shape, data_types::f16, format::bfyx}; auto mask_mem = test_engine.allocate_memory(mask_layout); @@ -502,6 +581,77 @@ struct xAttentionReference { return mask_mem; } + memory::ptr get_mask_mem(int num_queries, int num_keys, int num_heads, int sliding_window_size) { + /* + * Two kinds of masks: + * + * Case 1 (N == K): + * num_queries = N + * num_keys = K = N + * k_head_size = H + * Q [N, H] * K[H, N] + * QK [N, N] + * 0 1 N + * 0 [ 0, MIN, .., MIN ] + * 1 [ 0, 0, .., MIN ] + * [ .., .., .., MIN ] + * N [ 0, 0, .., 0 ] + * + * Case 2 (N != K): + * num_queries = N + * num_keys = K + * k_head_size = H + * past_len = P = K - N + 1 + * Q [N, H] * K[H, K] + * QK [N, K] + * 0 1 2 P .. K + * 0 [ 0, 0, 0, MIN, MIN, MIN ] + * 1 [ 0, 0, 0, 0, MIN, MIN ] + * [ .., .., .., .., .., MIN ] + * N [ 0, 0, 0, 0, .., 0 ] + * + * Shapes: + * Q [1, num_heads, num_queries, k_head_size] + * K [1, num_heads, k_head_size, num_keys] + * Q*K [1, num_heads, num_queries, num_keys] + */ + + auto mask_shape = ov::PartialShape{ 1, 1, num_queries, num_keys }; + auto mask_layout = layout{mask_shape, data_types::f16, format::bfyx}; + auto mask_mem = test_engine.allocate_memory(mask_layout); + + mem_lock mem_ptr(mask_mem, test_stream); + + if (sliding_window_size == 0) { + int past_len = num_keys - num_queries + 1; + for (int i = 0; i < num_queries; i++) { + for (int j = 0; j < num_keys; j++) { + mem_ptr[i * num_keys + j] = j >= past_len + i ? std::numeric_limits::lowest() + : ov::float16(0.f); + } + } + } else { + int sliding_left = num_keys - num_queries - sliding_window_size + 1; + int past_len = num_keys - num_queries + 1; + + for (int i = 0; i < num_queries; i++) { + for (int j = 0; j < num_keys; j++) { + bool is_min; + if (num_queries == num_keys) { + is_min = (j >= sliding_left + i) && (j <= i) ? 0 : 1; + } else { + is_min = (j >= sliding_left + i) && (j < past_len + i) ? 0 : 1; + } + + mem_ptr[i * num_keys + j] = is_min ? std::numeric_limits::lowest() : ov::float16(0.f); + } + } + } + + return mask_mem; + } + + void rotate_block(std::vector& cache_data, std::vector rotation_deltas, std::vector rotation_trig_lut_mem, @@ -782,9 +932,9 @@ struct xAttentionTest : public ::testing::TestWithParam { output_scores_mem = outputs.at("output_scores").get_memory(); } auto ref_data = xAttentionReference(pam).get_reference(); - for (size_t i = 0; i < ref_data.first.size(); i++) { - std::cout << i << "reference = " << ref_data.first[i] << std::endl; - } + // for (size_t i = 0; i < ref_data.first.size(); i++) { + // std::cout << i << "reference = " << ref_data.first[i] << std::endl; + // } compare(output_data_mem, output_scores_mem, ref_data); } @@ -795,9 +945,16 @@ struct xAttentionTest : public ::testing::TestWithParam { for (size_t i = 0; i < data_output_mem->count(); i++) { std::cout << i << ": result = " << mem_ptr[i] << ", reference = " << ref_data.first[i] << std::endl; } - // for (size_t i = 0; i < data_output_mem->count(); i++) { - // ASSERT_NEAR(mem_ptr[i], ref_data.first[i], tolerance) << " at index=" << i; - // } + std::cout << "data_output_mem->count(): " << data_output_mem->count() << std::endl; + int num = 0; + for (size_t i = 0; i < data_output_mem->count(); i++) { + if (abs(mem_ptr[i] - ref_data.first[i]) > tolerance) { + // std::cout << "mem_ptr: " << mem_ptr[i] << " " << "ref_data: " << ref_data.first[i] << std::endl; + num++; + } + // ASSERT_NEAR(mem_ptr[i], ref_data.first[i], tolerance) << " at index=" << i; + } + std::cout << "num: " << num << std::endl; } if (scores_output_mem) { @@ -852,8 +1009,9 @@ INSTANTIATE_TEST_SUITE_P(smoke_xattention, #if ENABLE_PA_CM_PATH /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, token_size>=32, disable_mix_mode */ - // xattention_test_params{ {{32, 0}}, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - xattention_test_params{ {{1024, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + // xattention_test_params{ {{32, 0}}, 2, 2, 2, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + xattention_test_params{ {{4096, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + // xattention_test_params{ {{1024, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long // xattention_test_params{ {{1024, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, // DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long diff --git a/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp b/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp index 60e5355894f62b..bd234c5bd42aeb 100644 --- a/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp +++ b/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp @@ -598,6 +598,50 @@ struct PagedAttentionManager { return data; } +static std::vector generate_input_data_ww( + tests::random_generator& rg, + size_t num_heads, + size_t tokens_num, + size_t k_head_size, + float stddev = 0.5f, // 控制数据分布集中程度 + bool normalize = true // 是否对每个向量做归一化 +) { + const size_t total_elements_num = tokens_num * num_heads * k_head_size; + auto data = rg.generate_random_1d(total_elements_num, -1, 1); + + // 将均匀分布映射到近似正态分布 + for (size_t i = 0; i < total_elements_num; ++i) { + float x = static_cast(data[i]); + // Box-Muller transform for simple Gaussian-like distribution + float u1 = (x + 1.f) / 2.f; // [0,1] + float u2 = rg.generate_random_1d(1, 0.f, 1.f)[0]; // 另一个随机数 + float r = std::sqrt(-2.f * std::log(u1 + 1e-6f)) * stddev; // 避免 log(0) + float theta = 2.f * 3.1415926535f * u2; + float val = r * std::cos(theta); + data[i] = ov::float16(val); + } + + if (normalize) { + // 对每个 head 的每个 token 做 L2 归一化 + for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { + for (size_t token_idx = 0; token_idx < tokens_num; ++token_idx) { + float norm = 0.f; + for (size_t dim = 0; dim < k_head_size; ++dim) { + float val = static_cast(data[head_idx * tokens_num * k_head_size + token_idx * k_head_size + dim]); + norm += val * val; + } + norm = std::sqrt(norm) + 1e-6f; + for (size_t dim = 0; dim < k_head_size; ++dim) { + size_t idx = head_idx * tokens_num * k_head_size + token_idx * k_head_size + dim; + data[idx] = ov::float16(static_cast(data[idx]) / norm); + } + } + } + } + + return data; +} + static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { const size_t total_elements_num = per_block ? rotated_blocks_num : rotated_blocks_num * block_size; From 3da8a3464210b8a299cd058f550ce98a1b96600a Mon Sep 17 00:00:00 2001 From: Luwei Zhou Date: Tue, 14 Oct 2025 14:25:03 +0800 Subject: [PATCH 58/96] Fix the KV cache padding with Nan issue for 1st token. --- .../graph/impls/cm/include/cm_pa_common.hpp | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp index 59ef0081413dec..f9a7e3fd1aaa03 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp @@ -362,6 +362,12 @@ void pa_kernel_lsc_prefetch_f16( b2dK.set_base_ptr((reinterpret_cast(k_cache_base)+cur_block_id*blk_stride)); b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ); cm_load(Kmat.format(), b2dK.set_block_x(0)); + // somtimes KV cache would be filled with random Nan, so need to clean up the unused key data. + if ((kv_pos + kv_step) > kv_stop) { + auto valid_rows = kv_stop - kv_pos; + for (int r = valid_rows; r < kv_step; r++) + Kmat.format().row(r) = 0.f; + } #pragma unroll for(int k = 0; k < num_K; k++) St2.row(k) = cm_dpas( @@ -415,6 +421,15 @@ void pa_kernel_lsc_prefetch_f16( matrix Vmat; cm_prefetch(prefetch_V.set_block_x(k)); cm_load(Vmat.format(), b2dV.set_block_x(k)); + // somtimes KV cache would be filled with random Nan, so need to clean up the unused value data. + if ((kv_pos + kv_step) > kv_stop) { + uint valid_rows = kv_stop - kv_pos; + uint valid_rows_vnni = (valid_rows+1)/2; + for (int r = valid_rows_vnni; r < kv_step / 2; r++) + Vmat.row(r) = 0.f; + if (valid_rows % 2 == 1) + Vmat.row(valid_rows_vnni-1).select(1) = 0.f; + } #pragma unroll for(int p = 0; p < num_P_tiles; p++) { rO[ri + p] = cm_dpas( @@ -433,7 +448,15 @@ void pa_kernel_lsc_prefetch_f16( cm_prefetch(prefetch_V.set_block_x(k)); cm_load(Vmat.format(), b2dV.set_block_x(k)); - + // somtimes KV cache would be filled with random Nan, so need to clean up the unused value data. + if ((kv_pos + kv_step) > kv_stop) { + uint valid_rows = kv_stop - kv_pos; + uint valid_rows_vnni = (valid_rows+1)/2; + for (int r = valid_rows_vnni; r < kv_step / 2; r++) + Vmat.row(r) = 0.f; + if (valid_rows % 2 == 1) + Vmat.row(valid_rows_vnni-1).select(1) = 0.f; + } //# compensate cur_O // matrix rO; #pragma unroll From 2dd7a813c6099142a6281b4f3db3c8464d671361 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Tue, 14 Oct 2025 16:34:52 +0800 Subject: [PATCH 59/96] Fix nan issue for 2nd token --- .../graph/impls/cm/pa_kv_cache_update_ref.cm | 24 +------------------ .../src/graph/impls/cm/pa_single_token.cm | 17 +++++++++++++ .../graph/impls/cm/paged_attention_gen.cpp | 4 ++-- 3 files changed, 20 insertions(+), 25 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm index 7016dcddcc6bb9..fc7563e567ab3c 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm @@ -70,7 +70,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const uint token_idx = cm_global_id(2); // token_idx -> subsequence_idx - // if (token_idx >= subsequence_begins[batch_size_in_sequences]) return; + if (token_idx >= subsequence_begins[batch_size_in_sequences]) return; uint subsequence_idx = 0; for (uint i = 0; i < batch_size_in_sequences; i++) { if (token_idx >= subsequence_begins[i] && token_idx < subsequence_begins[i + 1]) { @@ -87,28 +87,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const uint token_start_pos = (past_len + token_idx - subsequence_begin_idx) % PAGED_ATTENTION_BLOCK_SIZE; const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx; - if (token_idx >= subsequence_begins[batch_size_in_sequences]) { - #if KV_CACHE_COMPRESSION_PER_TOKEN - #else - // In PTL some V cache are written with NAN or random value due to unknown reason, while PA kernel will leverage lsc cm_load to - // load V cache by 16x16 block with vnni format, it is hard to exclude the unused V cache when NAN is involved in the same 16x16 block. - // Once NAN takes part in dpas, the NAN will propagate and cause result become NAN. - // As a WA, we need to set the unused part(in the same 16 row) of V cache to 0 here. - const uint last_token_idx = (past_len + 1) % PAGED_ATTENTION_BLOCK_SIZE; - - if (token_idx >= last_token_idx && token_idx < PAGED_ATTENTION_BLOCK_SIZE) { - uint block_k_base_offset = ((past_len + 1) / PAGED_ATTENTION_BLOCK_SIZE) * KV_HEADS_NUM * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; - uint key_out_offset = block_k_base_offset + head_idx * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + token_idx * ADJUSTED_K_HEAD_SIZE; - vector zero_data = 0; - - // Only reset unused part in the same 16 row for V cache. - // cm_ptr_store((int*)key_cache, key_out_offset * (int)sizeof(half), zero_data.format()); - cm_ptr_store((int*)value_cache, key_out_offset * (int)sizeof(half), zero_data.format()); - } - #endif - return; - } - #if KV_CACHE_COMPRESSION_PER_TOKEN // Assume: K_HEAD_SIZE == K_HEAD_SIZE auto quantize_and_store = [&](vector data, uchar* out, uint out_offset, uint token_pos) { diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index f25f571025b3f0..2cdf221c1e6042 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -219,6 +219,13 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } #else cm_load(Kt.format(), b2dK.set_block_y(kv_pos)); + if(kv_pos_end < kv_pos + KV_STEP) { + auto KmatRef = Kt.format(); + uint valid_cols = kv_pos_end - kv_pos; + uint valid_cols_vnni = valid_cols * 2; + for (int r = valid_cols_vnni; r < KV_STEP * 2; r++) + KmatRef.select(0,r) = 0.0f; + } #endif #else matrix temp; @@ -398,6 +405,16 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( prepack_to_VNNI_W2(VmatNormal, Vmat.format()); #else cm_load(Vmat[0].format(), b2dV.set_block_y(kv_pos)); + // somtimes KV cache would be filled with random Nan, so need to clean up the unused value data. + if(kv_pos_end - kv_pos < KV_STEP) { + auto VmatRef = Vmat[0].format(); + uint valid_rows = kv_pos_end - kv_pos; + uint valid_rows_vnni = (valid_rows+1)/2; + for (int r = valid_rows_vnni; r < KV_STEP / 2; r++) + VmatRef.row(r) = 0.f; + if (valid_rows % 2 == 1) + VmatRef.row(valid_rows_vnni-1).select(1) = 0.f; + } #endif #else matrix temp; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index af082f6e1e6775..84c58e595dcff6 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -222,9 +222,9 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() const size_t kv_len = get_input_kv_len(params); const size_t kv_heads_num = desc->kv_heads_num; - const size_t wg_count = (kv_len + PA_KV_CACHE_BLOCK_SIZE - 1) / PA_KV_CACHE_BLOCK_SIZE; + const size_t wg_count = (kv_len + WG_SIZE - 1) / WG_SIZE; - wgs.global = {1, kv_heads_num, wg_count * PA_KV_CACHE_BLOCK_SIZE}; + wgs.global = {1, kv_heads_num, wg_count * WG_SIZE}; wgs.local = {1, 1, WG_SIZE}; auto& scalars = kd.params.scalars; From bbf17edd8a0b978bdef25df147d83e506677cdc8 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 14 Oct 2025 22:19:27 +0800 Subject: [PATCH 60/96] Clean code --- .../src/graph/impls/cm/paged_attention.cpp | 32 - .../unit/test_cases/xattention_gpu_test.cpp | 724 +++++++++--------- 2 files changed, 360 insertions(+), 396 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index a640dba5b04545..32a198dbdd2957 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -125,38 +125,6 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { // ov::util::save_binary(filename, lock.data(), output_mem->size()); // pa_id++; // } - { - cldnn::stream& stream = instance.get_network().get_stream(); - stream.finish(); - static uint32_t pa_id = 0; - std::cout << "finish xattn_estimate_find_block!\n"; - for (int index = 0; index < 5; index++) { - auto output_mem = instance.get_intermediates_memories()[4]; - mem_lock lock(output_mem, stream); - auto& layout = output_mem->get_layout(); - auto dims = layout.get_dims(); - size_t total_size = output_mem->size(); - - std::cout << "PA" << pa_id << " layout: rank=" << layout.get_rank() - << ", dims=["; - for (size_t r = 0; r < dims.size(); r++) { - std::cout << dims[r]; - if (r != dims.size() - 1) std::cout << ","; - } - std::cout << "], total_size=" << total_size << "\n"; - - size_t max_print = total_size; //std::min(100, total_size); - std::cout << "Data: "; - for (size_t i = 0; i < max_print; i++) { - if (i % 32 == 0) std::cout << std::endl; - std::cout << static_cast(lock.data()[i]) << " "; - } - if (total_size > max_print) std::cout << "..."; - std::cout << "\n"; - } - - pa_id++; - } // #endif res_event = {execute_stage(res_event, instance, xattn_estimate_post_proc)}; } diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index 9fe17c5274783b..3c970d2cdcad62 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -16,22 +16,354 @@ #include "paged_attention_gpu_test.hpp" #include "random_generator.hpp" #include "test_utils.h" -#include -#include - using namespace cldnn; using namespace ov::intel_gpu; using namespace ::tests; -namespace std { -template <> -struct hash { - uint64_t operator()(const ov::float16 __val) const { - return std::hash()(__val); +using Shape = std::vector; + +using CMXAttentionBlockIndex = std::pair; // .first is the *query* dimension block index, .second is *key* +using CMXAttentionRetainedBlockIndices = std::set; +using CMXAttentionRetainedBlockIndicesForAllHeads = std::vector; + +template +class CMXAttentionBlockSelector { +public: + CMXAttentionBlockSelector(double threshold, size_t block_size, size_t stride) : m_threshold(threshold), m_block_size(block_size), m_stride(stride) { + OPENVINO_ASSERT(m_block_size % m_stride == 0); + } + + void diagonal_reshape(const T* input_data, const Shape& input_shape, T* output_data, const Shape& out_shape, bool is_antidiagonal) { + OPENVINO_ASSERT(input_shape.size() == 3); + OPENVINO_ASSERT(out_shape.size() == 3); + OPENVINO_ASSERT(input_shape[0] == out_shape[0]); + OPENVINO_ASSERT(input_shape[1] % m_stride == 0); + OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]); + OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]); + + size_t num_stride_steps = input_shape[1] / m_stride; + for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) { + size_t head_offset = head_idx * input_shape[1] * input_shape[2]; + for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) { + for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) { + size_t input_offset = head_offset; + size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2]; + if (is_antidiagonal) { + input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2]; + } else { + input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2]; + } + std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T)); + } + } + } + } + + void diagonal_reshape_kdb1_no_batch(const T* input_data, + const std::vector& input_shape, // [H, Q_orig, dim] + T* output_data, + const std::vector& output_shape) { + size_t H = input_shape[0]; + size_t Q_orig = input_shape[1]; + size_t dim = input_shape[2]; + size_t Q_new = output_shape[1]; + + for (size_t h = 0; h < H; ++h) { + size_t head_in_offset = h * Q_orig * dim; + size_t head_out_offset = h * Q_new * m_stride * dim; + + for (size_t s = 0; s < m_stride; ++s) { + for (size_t q = 0; q < Q_new; ++q) { + size_t in_idx = head_in_offset + (m_stride - 1 - s + q * m_stride) * dim; + size_t out_idx = head_out_offset + q * m_stride * dim + s * dim; + std::memcpy(output_data + out_idx, input_data + in_idx, dim * sizeof(T)); + } + } + } + } + + void transpose_matmul_scale(const T* reshaped_query_data, + const T* reshaped_key_data, + const Shape& reshaped_query_shape, + const Shape& reshaped_key_shape, + T* out, + const Shape& out_shape) { + OPENVINO_ASSERT(reshaped_key_shape.size() == 3); + OPENVINO_ASSERT(reshaped_query_shape.size() == 3); + OPENVINO_ASSERT(reshaped_query_shape[0] == reshaped_key_shape[0]); + OPENVINO_ASSERT(reshaped_query_shape[2] == reshaped_key_shape[2]); + + OPENVINO_ASSERT(out_shape.size() == 3); + OPENVINO_ASSERT(out_shape[0] == reshaped_query_shape[0]); + OPENVINO_ASSERT(out_shape[1] == reshaped_query_shape[1]); + OPENVINO_ASSERT(out_shape[2] == reshaped_key_shape[1]); + + ov::reference::matmul(reshaped_query_data, reshaped_key_data, out, reshaped_query_shape, reshaped_key_shape, out_shape, false, true); + + size_t out_size = out_shape[0] * out_shape[1] * out_shape[2]; + + for (size_t i = 0; i < out_size; i++) { + out[i] = out[i] / std::sqrt(reshaped_query_shape[2] * m_stride); + } + } + + void softmax(const T* reshaped_qk_product_data, const Shape& reshaped_qk_product_shape, T* out, const Shape& out_shape) { + OPENVINO_ASSERT(reshaped_qk_product_shape.size() == 3); + OPENVINO_ASSERT(reshaped_qk_product_shape == out_shape); + ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2}); + } + + void block_sum_attention_scores(const T* attention_scores_data, const Shape& attention_scores_shape, T* out, const Shape& out_shape) { + OPENVINO_ASSERT(attention_scores_shape.size() == 3); + size_t antidiagonals_per_xattention_block = m_block_size / m_stride; + OPENVINO_ASSERT(attention_scores_shape[1] % antidiagonals_per_xattention_block == 0); + OPENVINO_ASSERT(attention_scores_shape[2] % antidiagonals_per_xattention_block == 0); + + OPENVINO_ASSERT(out_shape[0] == attention_scores_shape[0]); + OPENVINO_ASSERT(out_shape[1] == attention_scores_shape[1] / antidiagonals_per_xattention_block); + OPENVINO_ASSERT(out_shape[2] == attention_scores_shape[2] / antidiagonals_per_xattention_block); + + std::memset(out, 0, out_shape[0] * out_shape[1] * out_shape[2] * sizeof(T)); + + for (size_t head_idx = 0; head_idx < attention_scores_shape[0]; head_idx++) { + size_t in_head_offset = head_idx * attention_scores_shape[1] * attention_scores_shape[2]; + size_t out_head_offset = head_idx * out_shape[1] * out_shape[2]; + for (size_t query_len_idx = 0; query_len_idx < attention_scores_shape[1]; query_len_idx++) { + for (size_t key_len_idx = 0; key_len_idx < attention_scores_shape[2]; key_len_idx++) { + size_t query_block_idx = query_len_idx / antidiagonals_per_xattention_block; + size_t key_block_idx = key_len_idx / antidiagonals_per_xattention_block; + auto target_block_sum_ptr = out + out_head_offset + query_block_idx * out_shape[2] + key_block_idx; + *target_block_sum_ptr += *(attention_scores_data + in_head_offset + query_len_idx * attention_scores_shape[2] + key_len_idx); + } + } + } + } + + CMXAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(T* blocked_attention_scores_data, const Shape& blocked_attention_scores_shape) { + OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3, "Expected shape [num_heads, q_block_num, k_block_num]"); + + size_t num_heads = blocked_attention_scores_shape[0]; + size_t q_block_num = blocked_attention_scores_shape[1]; + size_t k_block_num = blocked_attention_scores_shape[2]; + + CMXAttentionRetainedBlockIndicesForAllHeads retval(num_heads); + + std::vector>> mask(num_heads, std::vector>(q_block_num, std::vector(k_block_num, false))); + + for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { + for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { + size_t diagonal_k = q_block_idx; + if (diagonal_k < k_block_num) { + mask[head_idx][q_block_idx][diagonal_k] = true; + } + // Step1: Keep the first column + mask[head_idx][q_block_idx][0] = true; + + // Step2: Create other_values(masked_fill) + std::vector> other_values; + for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { + if (mask[head_idx][q_block_idx][k_block_idx]) + continue; + size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; + other_values.emplace_back(static_cast(blocked_attention_scores_data[offset]), k_block_idx); + } + + // Step3: Sort other-values in descending order + std::sort(other_values.begin(), other_values.end(), [](const auto& a, const auto& b) { + return a.first > b.first; + }); + + // Step4: Create cumulative_sum_without_self,cat([0, diagonal_sum, sorted_values[:-1]]) + std::vector sorted_scores; + sorted_scores.push_back(0.0); + size_t offset_diag = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + diagonal_k; + float diag_score = static_cast(blocked_attention_scores_data[offset_diag]); + float first_col_score = 0.0; + if (diagonal_k != 0) { + size_t offset_first = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + 0; + first_col_score = static_cast(blocked_attention_scores_data[offset_first]); + } + sorted_scores.push_back(diag_score + first_col_score); + + for (auto& p : other_values) { + sorted_scores.push_back(p.first); + } + if (q_block_idx == 0) { + sorted_scores.pop_back(); + } + + // Step5: Calculate cumsum_without_self: cumsum of right-shifted sorted_scores + std::vector cumsum_without_self(sorted_scores.size(), 0.0); + float running = 0.0; + for (size_t i = 0; i < sorted_scores.size(); ++i) { + cumsum_without_self[i] = running; + running += sorted_scores[i]; + } + + // Step6: Generate required_sum + size_t offset_row_start = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num; + float row_sum = 0.0; + for (size_t k = 0; k < k_block_num; k++) { + row_sum += static_cast(blocked_attention_scores_data[offset_row_start + k]); + } + float required_sum = row_sum * m_threshold; + + // Step7: Create index_mask + std::vector index_mask(cumsum_without_self.size(), false); + for (size_t i = 0; i < cumsum_without_self.size(); i++) { + index_mask[i] = (cumsum_without_self[i] < required_sum); + } + + // Step8: Ceate index + std::vector index(index_mask.size(), 0); + for (size_t i = 0; i < index_mask.size(); i++) { + if (index_mask[i]) { + if (i == 0) + index[i] = 0; + else if (i == 1) + index[i] = diagonal_k; + else if (i - 2 < other_values.size()) + index[i] = other_values[i - 2].second; + else + index[i] = 0; + } + } + + for (size_t i = 0; i < index.size(); i++) { + size_t k_block_idx = index[i]; + if (index_mask[i] && k_block_idx < k_block_num) { + mask[head_idx][q_block_idx][k_block_idx] = true; + } + } + for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { + if (mask[head_idx][q_block_idx][k_block_idx]) + retval[head_idx].insert({q_block_idx, k_block_idx}); + } + } + } + + return retval; } + + CMXAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, const Shape& query_shape, const T* key_data, const Shape& key_shape) { + OPENVINO_ASSERT(query_shape.size() == 3); + OPENVINO_ASSERT(key_shape.size() == 3); + OPENVINO_ASSERT(key_shape[0] == query_shape[0]); + OPENVINO_ASSERT(key_shape[2] == query_shape[2]); + OPENVINO_ASSERT(query_shape[1] % m_stride == 0); + OPENVINO_ASSERT(key_shape[1] % m_stride == 0); + OPENVINO_ASSERT(query_shape[1] % m_block_size == 0); + OPENVINO_ASSERT(key_shape[1] % m_block_size == 0); + + size_t chunk_size = query_shape[1]; + size_t k_len = key_shape[1]; + size_t head_dim = query_shape[2]; + size_t num_heads = query_shape[0]; + size_t k_num_to_pad = ((k_len + chunk_size - 1) / chunk_size) * chunk_size - k_len; + Shape pad_key_shape = {num_heads, k_len + k_num_to_pad, head_dim}; + auto pad_key_buf = allocate_buf(pad_key_shape); + + for (size_t h = 0; h < num_heads; h++) + for (size_t t = 0; t < k_len; t++) + for (size_t d = 0; d < head_dim; d++) { + size_t offset = h * (k_len + k_num_to_pad) * head_dim + t * head_dim + d; + size_t original_offset = h * k_len * head_dim + t * head_dim + d; + pad_key_buf.get()[offset] = key_data[original_offset]; + } + + size_t k_chunk_num = (k_len + k_num_to_pad) / chunk_size; + size_t offset_token_chunk_num = k_chunk_num - 1; + size_t reshaped_chunk_size = chunk_size / m_stride; + size_t k_reshaped_num_to_pad = k_num_to_pad / m_stride; + size_t k_reshaped_seq_len = (k_len + k_num_to_pad) / m_stride; + + Shape reshaped_query_shape = {num_heads, query_shape[1] / m_stride, head_dim * m_stride}; + auto q_buf = allocate_buf(reshaped_query_shape); + diagonal_reshape_kdb1_no_batch(query_data, query_shape, q_buf.get(), reshaped_query_shape); + Shape reshaped_key_shape = {num_heads, pad_key_shape[1] / m_stride, head_dim * m_stride}; + auto k_buf = allocate_buf(reshaped_key_shape); + diagonal_reshape(pad_key_buf.get(), pad_key_shape, k_buf.get(), reshaped_key_shape, false); + Shape transpose_matmul_scaled_shape = {num_heads, query_shape[1] / m_stride, pad_key_shape[1] / m_stride}; + auto qk_buf = allocate_buf(transpose_matmul_scaled_shape); + + transpose_matmul_scale(q_buf.get(), k_buf.get(), reshaped_query_shape, reshaped_key_shape, qk_buf.get(), transpose_matmul_scaled_shape); + q_buf.reset(); + k_buf.reset(); + + Shape causal_mask_shape = {num_heads, reshaped_chunk_size, reshaped_chunk_size * k_chunk_num}; + auto causal_mask_buf = allocate_buf(causal_mask_shape); + std::fill(causal_mask_buf.get(), causal_mask_buf.get() + ov::shape_size(causal_mask_shape), T(0)); + if (k_reshaped_num_to_pad) { + for (size_t h = 0; h < num_heads; h++) + for (size_t q = 0; q < reshaped_chunk_size; q++) + for (size_t k = k_reshaped_seq_len - k_reshaped_num_to_pad; k < k_reshaped_seq_len; k++) { + size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + q * (reshaped_chunk_size * k_chunk_num) + k; + + causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); + } + } + + size_t chunk_start = offset_token_chunk_num * reshaped_chunk_size; + size_t chunk_end = chunk_start + reshaped_chunk_size; + + for (size_t h = 0; h < num_heads; h++) { + for (size_t q = 0; q < reshaped_chunk_size; q++) { + for (size_t k = q + 1; k < reshaped_chunk_size; k++) { + size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + q * (reshaped_chunk_size * k_chunk_num) + chunk_start + k; + causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); + } + } + } + + for (size_t h = 0; h < num_heads; h++) { + for (size_t q = 0; q < reshaped_chunk_size; q++) { + for (size_t k = chunk_end; k < reshaped_chunk_size * k_chunk_num; k++) { + size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + q * (reshaped_chunk_size * k_chunk_num) + k; + causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); + } + } + } + size_t out_size = transpose_matmul_scaled_shape[0] * transpose_matmul_scaled_shape[1] * transpose_matmul_scaled_shape[2]; + + for (size_t i = 0; i < out_size; i++) { + qk_buf.get()[i] += causal_mask_buf.get()[i]; + } + + causal_mask_buf.reset(); + Shape attention_scores_shape = transpose_matmul_scaled_shape; + auto attn_score_buf = allocate_buf(attention_scores_shape); + softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape); + qk_buf.reset(); + + size_t antidiagonals_per_xattention_block = m_block_size / m_stride; + Shape block_sum_shape = {attention_scores_shape[0], + attention_scores_shape[1] / antidiagonals_per_xattention_block, + attention_scores_shape[2] / antidiagonals_per_xattention_block}; + + auto block_sum_buf = allocate_buf(block_sum_shape); + block_sum_attention_scores(attn_score_buf.get(), attention_scores_shape, block_sum_buf.get(), block_sum_shape); + attn_score_buf.reset(); + auto selected_block_indices = get_block_indices_to_keep(block_sum_buf.get(), block_sum_shape); + block_sum_buf.reset(); + + return selected_block_indices; + } + + std::shared_ptr allocate_buf(const Shape& shape) { + return std::shared_ptr(new T[ov::shape_size(shape)]); + } + + size_t pad_to_block(size_t token_length) { + return (token_length + m_block_size - 1) / m_block_size * m_block_size; + } + + double m_threshold; + + size_t m_block_size; + + size_t m_stride; }; -} // namespace std struct xAttentionReference { xAttentionReference(PagedAttentionManager& pam) : pam(pam), test_engine(pam.test_engine), test_stream(pam.test_stream) {} @@ -93,169 +425,6 @@ struct xAttentionReference { } private: - void print_tensor(const std::vector& data, size_t heads, size_t rows, size_t cols, const std::string& name) { - std::cout << name << " (" << heads << "x" << rows << "x" << cols << "):\n"; - for (size_t h = 0; h < heads; h++) { - std::cout << " Head " << h << ":\n"; - for (size_t i = 0; i < rows; i++) { - for (size_t j = 0; j < cols; j++) { - std::cout << static_cast(data[h * rows * cols + i * cols + j]) << "\n"; - } - std::cout << "\n"; - } - } - } - - std::vector softmax_1(const std::vector& logits) { - std::vector out(logits.size()); - float max_val = *std::max_element(logits.begin(), logits.end()); - float sum = 0.0f; - for (float v : logits) - sum += std::exp(v - max_val); - for (size_t i = 0; i < logits.size(); i++) { - out[i] = static_cast(std::exp(logits[i] - max_val) / sum); - } - return out; - } - - std::vector safe_softmax(const std::vector& logits) { - std::vector probs(logits.size(), 0.0f); - float max_logit = -std::numeric_limits::infinity(); - for (float l : logits) - max_logit = std::max(max_logit, l); - if (std::isinf(max_logit)) - return probs; - - float sum_exp = 0.0f; - for (float l : logits) - sum_exp += std::exp(l - max_logit); - if (sum_exp == 0.0f) - return probs; - - for (size_t i = 0; i < logits.size(); ++i) - probs[i] = std::exp(logits[i] - max_logit) / sum_exp; - return probs; - } - - std::vector compute_sparse_causal_attention(const std::vector& Q_in, // [B, Tq, H, Dq] - const std::vector& K_in, // [B, Tk, H, Dk] - const std::vector& V_in, // [B, Tk, H, Dv] - size_t num_heads, - size_t num_queries, - size_t num_keys, - size_t qk_head_dim, - size_t v_head_dim, - const ov::reference::XAttentionRetainedBlockIndicesForAllHeads& retained_blocks_for_all_heads = {}, - float scale = 0.0f, - size_t block_size = 1) { - if (scale == 0.0f) - scale = 1.0f / std::sqrt(static_cast(qk_head_dim)); - - bool use_sparse = !retained_blocks_for_all_heads.empty(); - std::vector output(num_heads * num_queries * v_head_dim, ov::float16(0.0f)); - - std::cout << "---- compute_sparse_causal_attention ----\n"; - std::cout << "num_heads=" << num_heads << " num_queries=" << num_queries << " num_keys=" << num_keys << " qk_head_dim=" << qk_head_dim - << " v_head_dim=" << v_head_dim << " scale=" << scale << "\n"; - - // ======== permute Q,K,V from [B,T,H,D] → [H,T,D] ======== - std::vector Q(num_heads * num_queries * qk_head_dim); - std::vector K(num_heads * num_keys * qk_head_dim); - std::vector V(num_heads * num_keys * v_head_dim); - - for (size_t h = 0; h < num_heads; ++h) { - for (size_t t = 0; t < num_queries; ++t) { - for (size_t d = 0; d < qk_head_dim; ++d) { - Q[h * num_queries * qk_head_dim + t * qk_head_dim + d] = Q_in[t * num_heads * qk_head_dim + h * qk_head_dim + d]; - } - } - for (size_t t = 0; t < num_keys; ++t) { - for (size_t d = 0; d < qk_head_dim; ++d) { - K[h * num_keys * qk_head_dim + t * qk_head_dim + d] = K_in[t * num_heads * qk_head_dim + h * qk_head_dim + d]; - } - for (size_t d = 0; d < v_head_dim; ++d) { - V[h * num_keys * v_head_dim + t * v_head_dim + d] = V_in[t * num_heads * v_head_dim + h * v_head_dim + d]; - } - } - } - - // ======== Attention per head ======== - for (size_t h = 0; h < num_heads; ++h) { - const auto& retained_blocks = use_sparse ? retained_blocks_for_all_heads[h] : ov::reference::XAttentionRetainedBlockIndices{}; - - if (use_sparse) { - std::cout << "Head " << h << " retained blocks: "; - for (const auto& blk : retained_blocks) - std::cout << "(" << blk.first << "," << blk.second << ") "; - std::cout << std::endl; - } - - for (size_t q = 0; q < num_queries; ++q) { - std::vector logits(num_keys, -1e9f); - bool any_valid = false; - - for (size_t k = 0; k < num_keys; ++k) { - size_t q_block = q / block_size; - size_t k_block = k / block_size; - - if (use_sparse && retained_blocks.find({q_block, k_block}) == retained_blocks.end()) - continue; - if (k > q) - continue; // causal mask - - float score = 0.0f; - for (size_t d = 0; d < qk_head_dim; ++d) - score += static_cast(Q[h * num_queries * qk_head_dim + q * qk_head_dim + d]) * - static_cast(K[h * num_keys * qk_head_dim + k * qk_head_dim + d]); - logits[k] = score * scale; - any_valid = true; - } - - if (!any_valid) { - std::cout << "Head " << h << ", Query " << q << " has no valid keys -> zero output.\n"; - continue; - } - - auto probs = safe_softmax(logits); - - for (size_t d = 0; d < v_head_dim; ++d) { - float acc = 0.0f; - for (size_t k = 0; k <= q; ++k) { - if (use_sparse && retained_blocks.find({q / block_size, k / block_size}) == retained_blocks.end()) - continue; - acc += probs[k] * static_cast(V[h * num_keys * v_head_dim + k * v_head_dim + d]); - } - output[h * num_queries * v_head_dim + q * v_head_dim + d] = static_cast(acc); - } - } - } - - // ======== Debug summary ======== - std::cout << "Output preview (head0, first few queries):\n"; - for (size_t q = 0; q < std::min(4, num_queries); ++q) { - std::cout << " Q" << q << ": "; - for (size_t d = 0; d < std::min(8, v_head_dim); ++d) - std::cout << static_cast(output[q * v_head_dim + d]) << " "; - std::cout << "\n"; - } - - return output; - } - - -// 保存为二进制 .bin 文件 -void save_tensor_to_bin(const std::string& filename, const std::vector& data) { - std::ofstream file(filename, std::ios::out | std::ios::binary); - if (!file) { - std::cerr << "Failed to open " << filename << " for writing" << std::endl; - return; - } - file.write(reinterpret_cast(data.data()), data.size() * sizeof(ov::float16)); - file.close(); - std::cout << "[Info] Saved " << filename << " (" << data.size() << " elements)" << std::endl; -} - - std::pair, std::vector> run_reference(const std::vector& query_data, const std::vector& key_data, const std::vector& value_data, @@ -270,7 +439,6 @@ void save_tensor_to_bin(const std::string& filename, const std::vector(num_heads), static_cast(num_queries), static_cast(k_head_size)}; ov::Shape key_shape_3d = {static_cast(num_heads), static_cast(num_keys), static_cast(k_head_size)}; - ov::reference::XAttentionRetainedBlockIndicesForAllHeads retained_blocks; - // { - // ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); - // retained_blocks = selector.select_blocks(query_data_3d.data(), query_shape_3d, key_data_3d.data(), key_shape_3d); - - // std::cout << "=== C++ 选中 blocks ===" << std::endl; - // for (size_t h = 0; h < retained_blocks.size(); ++h) { - // std::cout << "Head " << h << " selected blocks: "; - // for (const auto& idx_pair : retained_blocks[h]) { - // std::cout << "(" << idx_pair.first << "," << idx_pair.second << ") "; - // } - // std::cout << std::endl; - // } - // } - - - if (num_queries < static_cast(block_size)) { - // Case 1: too few queries — skip block selection - std::cout << "[Info] num_queries < block_size, skip block selection." << std::endl; - } else { - // Case 2: handle non-divisible length via padding + CMXAttentionRetainedBlockIndicesForAllHeads retained_blocks; + if (num_queries >= static_cast(block_size)) { size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; - - if (padded_q != static_cast(num_queries) || padded_k != static_cast(num_keys)) { - std::cout << "[Info] Padding Q/K length for block alignment: " - << "Q " << num_queries << "→" << padded_q - << ", K " << num_keys << "→" << padded_k << std::endl; - } - - // Build padded buffers for selection std::vector query_padded(num_heads * padded_q * k_head_size, ov::float16(0)); std::vector key_padded(num_heads * padded_k * k_head_size, ov::float16(0)); @@ -359,23 +500,6 @@ void save_tensor_to_bin(const std::string& filename, const std::vector(num_heads), padded_q, static_cast(k_head_size)}; ov::Shape key_shape_padded = {static_cast(num_heads), padded_k, static_cast(k_head_size)}; - - // === Save padded Q/K for Python comparison === - save_tensor_to_bin("q_padded.bin", query_padded); - save_tensor_to_bin("k_padded.bin", key_padded); - - std::ofstream meta("meta.txt"); - meta << "num_heads=" << num_heads << "\n"; - meta << "padded_q=" << padded_q << "\n"; - meta << "padded_k=" << padded_k << "\n"; - meta << "k_head_size=" << k_head_size << "\n"; - meta << "block_size=" << block_size << "\n"; - meta << "stride=" << stride << "\n"; - meta << "threshold=" << threshold << "\n"; - meta.close(); - std::cout << "[Info] Saved meta.txt with shape info" << std::endl; - - std::vector query_padded_f32(query_padded.size()); std::vector key_padded_f32(key_padded.size()); for (size_t i = 0; i < query_padded.size(); ++i) @@ -383,49 +507,11 @@ void save_tensor_to_bin(const std::string& filename, const std::vector(key_padded[i]); - ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); + CMXAttentionBlockSelector selector(threshold, block_size, stride); retained_blocks = selector.select_blocks(query_padded_f32.data(), query_shape_padded, key_padded_f32.data(), key_shape_padded); - - std::cout << "=== Selected blocks after padding ===" << std::endl; - for (size_t h = 0; h < retained_blocks.size(); ++h) { - std::cout << "Head " << h << " selected blocks: "; - for (const auto& idx_pair : retained_blocks[h]) { - std::cout << "(" << idx_pair.first << "," << idx_pair.second << ") "; - } - std::cout << std::endl; - } } - - // retained_blocks = { - // { // Head 0 - // {0,0}, {1,0}, {1,1}, {2,0}, {2,1}, {2,2}, {3,0}, {3,1}, {3,2}, {3,3}, {4,0}, {4,1}, {4,2}, {4,3}, {4,4}, {5,0}, {5,1}, {5,2}, {5,3}, {5,4}, {5,5}, {6,0}, {6,1}, {6,2}, {6,3}, {6,4}, {6,5}, {6,6}, {7,0}, {7,1}, {7,2}, {7,3}, {7,4}, {7,5}, {7,6}, {7,7}, {8,0}, {8,1}, {8,2}, {8,3}, {8,4}, {8,5}, {8,6}, {8,7}, {8,8}, {9,0}, {9,1}, {9,2}, {9,3}, {9,4}, {9,5}, {9,6}, {9,7}, {9,8}, {9,9}, {10,0}, {10,1}, {10,2}, {10,3}, {10,5}, {10,6}, {10,7}, {10,8}, {10,9}, {10,10}, {11,0}, {11,1}, {11,2}, {11,3}, {11,4}, {11,5}, {11,7}, {11,8}, {11,9}, {11,10}, {11,11}, {12,0}, {12,1}, {12,2}, {12,3}, {12,4}, {12,5}, {12,6}, {12,7}, {12,8}, {12,9}, {12,10}, {12,12}, {13,0}, {13,1}, {13,2}, {13,3}, {13,4}, {13,5}, {13,6}, {13,7}, {13,8}, {13,9}, {13,11}, {13,12}, {13,13}, {14,0}, {14,1}, {14,2}, {14,3}, {14,4}, {14,5}, {14,6}, {14,8}, {14,9}, {14,10}, {14,11}, {14,12}, {14,13}, {14,14}, {15,0}, {15,1}, {15,2}, {15,3}, {15,4}, {15,5}, {15,6}, {15,8}, {15,9}, {15,10}, {15,11}, {15,12}, {15,13}, {15,14}, {15,15}, {16,0}, {16,1}, {16,2}, {16,3}, {16,4}, {16,5}, {16,6}, {16,7}, {16,8}, {16,9}, {16,10}, {16,11}, {16,12}, {16,13}, {16,14}, {16,16}, {17,0}, {17,1}, {17,2}, {17,3}, {17,4}, {17,5}, {17,6}, {17,7}, {17,8}, {17,9}, {17,10}, {17,11}, {17,12}, {17,13}, {17,14}, {17,15}, {17,17}, {18,0}, {18,1}, {18,2}, {18,3}, {18,4}, {18,5}, {18,6}, {18,7}, {18,9}, {18,10}, {18,11}, {18,12}, {18,13}, {18,14}, {18,15}, {18,16}, {18,17}, {18,18}, {19,0}, {19,1}, {19,3}, {19,4}, {19,5}, {19,6}, {19,7}, {19,8}, {19,9}, {19,10}, {19,11}, {19,12}, {19,13}, {19,14}, {19,15}, {19,16}, {19,17}, {19,18}, {19,19}, {20,0}, {20,1}, {20,2}, {20,3}, {20,4}, {20,5}, {20,6}, {20,7}, {20,9}, {20,10}, {20,12}, {20,13}, {20,14}, {20,15}, {20,16}, {20,17}, {20,18}, {20,19}, {20,20}, {21,0}, {21,1}, {21,2}, {21,3}, {21,4}, {21,5}, {21,6}, {21,7}, {21,8}, {21,9}, {21,10}, {21,12}, {21,13}, {21,14}, {21,16}, {21,17}, {21,18}, {21,19}, {21,20}, {21,21}, {22,0}, {22,1}, {22,2}, {22,3}, {22,4}, {22,5}, {22,7}, {22,8}, {22,10}, {22,11}, {22,12}, {22,13}, {22,14}, {22,15}, {22,16}, {22,17}, {22,18}, {22,19}, {22,20}, {22,21}, {22,22}, {23,0}, {23,1}, {23,2}, {23,3}, {23,4}, {23,5}, {23,6}, {23,7}, {23,9}, {23,10}, {23,11}, {23,13}, {23,14}, {23,15}, {23,16}, {23,17}, {23,18}, {23,19}, {23,20}, {23,21}, {23,22}, {23,23}, {24,0}, {24,2}, {24,3}, {24,4}, {24,5}, {24,6}, {24,7}, {24,9}, {24,10}, {24,11}, {24,12}, {24,13}, {24,14}, {24,15}, {24,16}, {24,17}, {24,18}, {24,19}, {24,20}, {24,21}, {24,22}, {24,23}, {24,24}, {25,0}, {25,1}, {25,2}, {25,3}, {25,4}, {25,5}, {25,6}, {25,7}, {25,8}, {25,9}, {25,10}, {25,11}, {25,12}, {25,13}, {25,14}, {25,15}, {25,16}, {25,17}, {25,19}, {25,20}, {25,22}, {25,23}, {25,24}, {25,25}, {26,0}, {26,1}, {26,2}, {26,3}, {26,4}, {26,5}, {26,6}, {26,7}, {26,8}, {26,9}, {26,10}, {26,12}, {26,13}, {26,14}, {26,15}, {26,16}, {26,17}, {26,18}, {26,19}, {26,20}, {26,21}, {26,23}, {26,24}, {26,25}, {26,26}, {27,0}, {27,1}, {27,3}, {27,4}, {27,5}, {27,6}, {27,7}, {27,8}, {27,9}, {27,10}, {27,12}, {27,13}, {27,14}, {27,15}, {27,16}, {27,17}, {27,18}, {27,19}, {27,20}, {27,21}, {27,22}, {27,23}, {27,24}, {27,25}, {27,26}, {27,27}, {28,0}, {28,1}, {28,2}, {28,3}, {28,4}, {28,5}, {28,6}, {28,7}, {28,8}, {28,9}, {28,11}, {28,12}, {28,13}, {28,14}, {28,15}, {28,16}, {28,17}, {28,18}, {28,19}, {28,20}, {28,21}, {28,22}, {28,23}, {28,24}, {28,25}, {28,26}, {28,28}, {29,0}, {29,1}, {29,2}, {29,3}, {29,4}, {29,5}, {29,6}, {29,7}, {29,8}, {29,9}, {29,11}, {29,12}, {29,13}, {29,14}, {29,15}, {29,17}, {29,18}, {29,19}, {29,20}, {29,21}, {29,22}, {29,23}, {29,24}, {29,25}, {29,26}, {29,27}, {29,29}, {30,0}, {30,1}, {30,2}, {30,3}, {30,4}, {30,5}, {30,6}, {30,7}, {30,8}, {30,9}, {30,10}, {30,13}, {30,14}, {30,15}, {30,16}, {30,17}, {30,18}, {30,19}, {30,20}, {30,21}, {30,23}, {30,24}, {30,25}, {30,26}, {30,27}, {30,28}, {30,29}, {30,30}, {31,0}, {31,1}, {31,2}, {31,4}, {31,5}, {31,6}, {31,8}, {31,9}, {31,10}, {31,11}, {31,12}, {31,13}, {31,14}, {31,15}, {31,16}, {31,17}, {31,18}, {31,19}, {31,20}, {31,21}, {31,22}, {31,23}, {31,25}, {31,26}, {31,27}, {31,28}, {31,29}, {31,30}, {31,31} - // }, - // { // Head 1 - // {0,0}, {1,0}, {1,1}, {2,0}, {2,1}, {2,2}, {3,0}, {3,1}, {3,2}, {3,3}, {4,0}, {4,1}, {4,2}, {4,3}, {4,4}, {5,0}, {5,1}, {5,2}, {5,3}, {5,4}, {5,5}, {6,0}, {6,1}, {6,2}, {6,3}, {6,4}, {6,5}, {6,6}, {7,0}, {7,1}, {7,2}, {7,3}, {7,4}, {7,5}, {7,6}, {7,7}, {8,0}, {8,1}, {8,2}, {8,3}, {8,4}, {8,5}, {8,6}, {8,7}, {8,8}, {9,0}, {9,1}, {9,2}, {9,3}, {9,4}, {9,5}, {9,6}, {9,7}, {9,8}, {9,9}, {10,0}, {10,1}, {10,2}, {10,3}, {10,4}, {10,6}, {10,7}, {10,8}, {10,9}, {10,10}, {11,0}, {11,1}, {11,2}, {11,3}, {11,4}, {11,5}, {11,6}, {11,8}, {11,9}, {11,10}, {11,11}, {12,0}, {12,1}, {12,2}, {12,3}, {12,5}, {12,6}, {12,7}, {12,8}, {12,9}, {12,10}, {12,11}, {12,12}, {13,0}, {13,1}, {13,2}, {13,3}, {13,4}, {13,5}, {13,6}, {13,8}, {13,9}, {13,10}, {13,11}, {13,12}, {13,13}, {14,0}, {14,1}, {14,2}, {14,3}, {14,4}, {14,5}, {14,6}, {14,8}, {14,9}, {14,10}, {14,11}, {14,12}, {14,13}, {14,14}, {15,0}, {15,1}, {15,2}, {15,3}, {15,4}, {15,5}, {15,6}, {15,7}, {15,8}, {15,9}, {15,10}, {15,11}, {15,12}, {15,14}, {15,15}, {16,0}, {16,1}, {16,2}, {16,3}, {16,4}, {16,5}, {16,7}, {16,8}, {16,9}, {16,10}, {16,11}, {16,12}, {16,13}, {16,14}, {16,15}, {16,16}, {17,0}, {17,2}, {17,3}, {17,4}, {17,5}, {17,6}, {17,7}, {17,8}, {17,9}, {17,10}, {17,11}, {17,12}, {17,13}, {17,14}, {17,15}, {17,16}, {17,17}, {18,0}, {18,1}, {18,2}, {18,3}, {18,4}, {18,5}, {18,6}, {18,7}, {18,8}, {18,9}, {18,10}, {18,11}, {18,12}, {18,13}, {18,14}, {18,15}, {18,17}, {18,18}, {19,0}, {19,1}, {19,2}, {19,3}, {19,4}, {19,5}, {19,6}, {19,7}, {19,8}, {19,9}, {19,10}, {19,11}, {19,12}, {19,13}, {19,15}, {19,16}, {19,17}, {19,18}, {19,19}, {20,0}, {20,1}, {20,2}, {20,3}, {20,4}, {20,5}, {20,6}, {20,7}, {20,8}, {20,10}, {20,11}, {20,12}, {20,13}, {20,14}, {20,15}, {20,16}, {20,18}, {20,19}, {20,20}, {21,0}, {21,1}, {21,2}, {21,4}, {21,5}, {21,6}, {21,7}, {21,9}, {21,10}, {21,11}, {21,12}, {21,13}, {21,14}, {21,15}, {21,16}, {21,17}, {21,18}, {21,19}, {21,20}, {21,21}, {22,0}, {22,1}, {22,2}, {22,3}, {22,4}, {22,5}, {22,7}, {22,8}, {22,9}, {22,10}, {22,11}, {22,12}, {22,13}, {22,14}, {22,15}, {22,16}, {22,17}, {22,18}, {22,20}, {22,21}, {22,22}, {23,0}, {23,1}, {23,2}, {23,3}, {23,5}, {23,6}, {23,7}, {23,8}, {23,9}, {23,10}, {23,11}, {23,12}, {23,13}, {23,14}, {23,15}, {23,16}, {23,18}, {23,19}, {23,20}, {23,21}, {23,22}, {23,23}, {24,0}, {24,1}, {24,2}, {24,3}, {24,4}, {24,5}, {24,6}, {24,7}, {24,9}, {24,10}, {24,11}, {24,13}, {24,14}, {24,15}, {24,16}, {24,17}, {24,18}, {24,19}, {24,20}, {24,21}, {24,22}, {24,23}, {24,24}, {25,0}, {25,1}, {25,2}, {25,3}, {25,4}, {25,5}, {25,6}, {25,7}, {25,8}, {25,10}, {25,11}, {25,12}, {25,13}, {25,14}, {25,15}, {25,16}, {25,17}, {25,18}, {25,19}, {25,20}, {25,21}, {25,22}, {25,24}, {25,25}, {26,0}, {26,1}, {26,2}, {26,3}, {26,4}, {26,5}, {26,6}, {26,7}, {26,8}, {26,9}, {26,10}, {26,11}, {26,12}, {26,13}, {26,15}, {26,16}, {26,17}, {26,18}, {26,19}, {26,20}, {26,21}, {26,23}, {26,24}, {26,25}, {26,26}, {27,0}, {27,1}, {27,2}, {27,3}, {27,4}, {27,5}, {27,6}, {27,7}, {27,8}, {27,9}, {27,10}, {27,11}, {27,12}, {27,13}, {27,14}, {27,16}, {27,17}, {27,18}, {27,19}, {27,20}, {27,22}, {27,23}, {27,24}, {27,25}, {27,26}, {27,27}, {28,0}, {28,1}, {28,2}, {28,3}, {28,4}, {28,5}, {28,6}, {28,7}, {28,8}, {28,9}, {28,11}, {28,12}, {28,13}, {28,14}, {28,15}, {28,16}, {28,17}, {28,18}, {28,19}, {28,20}, {28,21}, {28,22}, {28,24}, {28,25}, {28,26}, {28,27}, {28,28}, {29,0}, {29,1}, {29,2}, {29,3}, {29,4}, {29,5}, {29,7}, {29,8}, {29,9}, {29,11}, {29,12}, {29,13}, {29,14}, {29,15}, {29,16}, {29,17}, {29,18}, {29,19}, {29,20}, {29,21}, {29,22}, {29,23}, {29,24}, {29,25}, {29,27}, {29,28}, {29,29}, {30,0}, {30,1}, {30,2}, {30,3}, {30,4}, {30,6}, {30,7}, {30,8}, {30,9}, {30,10}, {30,11}, {30,12}, {30,14}, {30,15}, {30,16}, {30,17}, {30,18}, {30,19}, {30,20}, {30,21}, {30,22}, {30,24}, {30,25}, {30,26}, {30,27}, {30,28}, {30,29}, {30,30}, {31,0}, {31,1}, {31,2}, {31,3}, {31,4}, {31,5}, {31,6}, {31,7}, {31,9}, {31,10}, {31,11}, {31,12}, {31,14}, {31,15}, {31,16}, {31,17}, {31,18}, {31,19}, {31,20}, {31,21}, {31,23}, {31,24}, {31,25}, {31,26}, {31,27}, {31,28}, {31,29}, {31,30}, {31,31} - // } - // }; - -// retained_blocks = {{ // Head 0 -// {0,0}, {1,0}, {1,1}, {2,0}, {2,1}, {2,2}, {3,0}, {3,1}, {3,2}, {3,3}, {4,0}, {4,1}, {4,2}, {4,3}, {4,4}, {5,0}, {5,1}, {5,2}, {5,3}, {5,4}, {5,5}, {6,0}, {6,1}, {6,2}, {6,3}, {6,4}, {6,5}, {6,6}, {7,0}, {7,1}, {7,2}, {7,3}, {7,4}, {7,5}, {7,6}, {7,7}, {8,0}, {8,1}, {8,2}, {8,3}, {8,4}, {8,5}, {8,6}, {8,7}, {8,8}, {9,0}, {9,1}, {9,2}, {9,3}, {9,4}, {9,5}, {9,6}, {9,7}, {9,8}, {9,9}, {10,0}, {10,1}, {10,2}, {10,3}, {10,4}, {10,5}, {10,6}, {10,7}, {10,9}, {10,10}, {11,0}, {11,1}, {11,2}, {11,3}, {11,4}, {11,5}, {11,6}, {11,7}, {11,9}, {11,10}, {11,11}, {12,0}, {12,1}, {12,2}, {12,3}, {12,4}, {12,6}, {12,7}, {12,8}, {12,9}, {12,10}, {12,11}, {12,12}, {13,0}, {13,2}, {13,3}, {13,4}, {13,5}, {13,6}, {13,7}, {13,8}, {13,9}, {13,10}, {13,11}, {13,12}, {13,13}, {14,0}, {14,1}, {14,2}, {14,3}, {14,4}, {14,5}, {14,7}, {14,8}, {14,9}, {14,10}, {14,11}, {14,12}, {14,13}, {14,14}, {15,0}, {15,1}, {15,2}, {15,3}, {15,5}, {15,6}, {15,7}, {15,8}, {15,9}, {15,10}, {15,11}, {15,12}, {15,13}, {15,14}, {15,15}, {16,0}, {16,1}, {16,2}, {16,3}, {16,4}, {16,5}, {16,6}, {16,7}, {16,8}, {16,9}, {16,10}, {16,12}, {16,13}, {16,14}, {16,15}, {16,16}, {17,0}, {17,1}, {17,2}, {17,4}, {17,5}, {17,6}, {17,7}, {17,8}, {17,9}, {17,10}, {17,11}, {17,12}, {17,13}, {17,14}, {17,15}, {17,16}, {17,17}, {18,0}, {18,1}, {18,2}, {18,3}, {18,4}, {18,5}, {18,7}, {18,8}, {18,9}, {18,10}, {18,11}, {18,12}, {18,13}, {18,14}, {18,15}, {18,16}, {18,17}, {18,18}, {19,0}, {19,1}, {19,2}, {19,4}, {19,5}, {19,6}, {19,7}, {19,8}, {19,9}, {19,10}, {19,11}, {19,12}, {19,13}, {19,14}, {19,15}, {19,16}, {19,17}, {19,18}, {19,19}, {20,0}, {20,1}, {20,2}, {20,3}, {20,4}, {20,5}, {20,6}, {20,7}, {20,9}, {20,10}, {20,11}, {20,13}, {20,14}, {20,15}, {20,16}, {20,17}, {20,18}, {20,19}, {20,20}, {21,0}, {21,1}, {21,2}, {21,3}, {21,5}, {21,6}, {21,7}, {21,8}, {21,9}, {21,10}, {21,11}, {21,12}, {21,14}, {21,15}, {21,16}, {21,17}, {21,18}, {21,19}, {21,20}, {21,21}, {22,0}, {22,1}, {22,2}, {22,3}, {22,4}, {22,5}, {22,7}, {22,8}, {22,9}, {22,10}, {22,11}, {22,12}, {22,14}, {22,15}, {22,16}, {22,17}, {22,18}, {22,19}, {22,20}, {22,21}, {22,22}, {23,0}, {23,1}, {23,2}, {23,3}, {23,4}, {23,5}, {23,7}, {23,8}, {23,9}, {23,10}, {23,11}, {23,12}, {23,13}, {23,14}, {23,16}, {23,17}, {23,18}, {23,19}, {23,20}, {23,21}, {23,22}, {23,23}, {24,0}, {24,1}, {24,2}, {24,3}, {24,5}, {24,6}, {24,7}, {24,8}, {24,10}, {24,11}, {24,12}, {24,13}, {24,14}, {24,15}, {24,16}, {24,17}, {24,18}, {24,19}, {24,20}, {24,21}, {24,22}, {24,23}, {24,24}, {25,0}, {25,1}, {25,2}, {25,3}, {25,6}, {25,7}, {25,8}, {25,9}, {25,10}, {25,11}, {25,12}, {25,13}, {25,14}, {25,15}, {25,16}, {25,17}, {25,18}, {25,19}, {25,20}, {25,21}, {25,22}, {25,23}, {25,24}, {25,25}, {26,0}, {26,1}, {26,2}, {26,3}, {26,4}, {26,5}, {26,6}, {26,7}, {26,8}, {26,9}, {26,10}, {26,11}, {26,12}, {26,13}, {26,14}, {26,15}, {26,16}, {26,17}, {26,18}, {26,20}, {26,21}, {26,22}, {26,23}, {26,24}, {26,26}, {27,0}, {27,1}, {27,2}, {27,3}, {27,4}, {27,7}, {27,8}, {27,9}, {27,10}, {27,11}, {27,12}, {27,13}, {27,14}, {27,15}, {27,16}, {27,17}, {27,18}, {27,19}, {27,20}, {27,21}, {27,22}, {27,23}, {27,24}, {27,25}, {27,26}, {27,27}, {28,0}, {28,1}, {28,2}, {28,3}, {28,4}, {28,5}, {28,6}, {28,7}, {28,9}, {28,10}, {28,11}, {28,12}, {28,13}, {28,14}, {28,15}, {28,16}, {28,17}, {28,18}, {28,19}, {28,20}, {28,21}, {28,22}, {28,23}, {28,25}, {28,26}, {28,27}, {28,28}, {29,0}, {29,1}, {29,2}, {29,3}, {29,4}, {29,5}, {29,6}, {29,7}, {29,8}, {29,9}, {29,10}, {29,11}, {29,12}, {29,13}, {29,14}, {29,15}, {29,16}, {29,17}, {29,18}, {29,19}, {29,22}, {29,23}, {29,24}, {29,25}, {29,26}, {29,27}, {29,28}, {29,29}, {30,0}, {30,1}, {30,2}, {30,3}, {30,4}, {30,5}, {30,6}, {30,7}, {30,8}, {30,9}, {30,10}, {30,11}, {30,12}, {30,13}, {30,14}, {30,15}, {30,16}, {30,17}, {30,18}, {30,19}, {30,20}, {30,22}, {30,23}, {30,24}, {30,25}, {30,27}, {30,28}, {30,30}, {31,0}, {31,1}, {31,2}, {31,3}, {31,4}, {31,5}, {31,6}, {31,8}, {31,9}, {31,10}, {31,11}, {31,12}, {31,13}, {31,14}, {31,15}, {31,16}, {31,17}, {31,18}, {31,19}, {31,20}, {31,22}, {31,23}, {31,24}, {31,25}, {31,27}, {31,28}, {31,29}, {31,30}, {31,31} -// }}; - - - // auto output = compute_sparse_causal_attention(query_data, - // key_data, - // value_data, - // num_heads, - // num_queries, - // num_keys, - // k_head_size, - // v_head_size, - // retained_blocks, - // 0.0f, - // block_size); - - // print_tensor(output, num_heads, num_queries, k_head_size, "Output"); auto mask_mem = get_mask_mem_combined_multi_head(num_queries, num_keys, num_heads, sliding_window_size, retained_blocks, block_size); - // auto mask_mem = get_mask_mem(num_queries, num_keys, num_heads, sliding_window_size); topology topology; topology.add(input_layout("query", query_layout), @@ -493,12 +579,9 @@ void save_tensor_to_bin(const std::string& filename, const std::vector mem_ptr(mask_mem, test_stream); - - if (sliding_window_size == 0) { - int past_len = num_keys - num_queries + 1; - for (int i = 0; i < num_queries; i++) { - for (int j = 0; j < num_keys; j++) { - mem_ptr[i * num_keys + j] = j >= past_len + i ? std::numeric_limits::lowest() - : ov::float16(0.f); - } - } - } else { - int sliding_left = num_keys - num_queries - sliding_window_size + 1; - int past_len = num_keys - num_queries + 1; - - for (int i = 0; i < num_queries; i++) { - for (int j = 0; j < num_keys; j++) { - bool is_min; - if (num_queries == num_keys) { - is_min = (j >= sliding_left + i) && (j <= i) ? 0 : 1; - } else { - is_min = (j >= sliding_left + i) && (j < past_len + i) ? 0 : 1; - } - - mem_ptr[i * num_keys + j] = is_min ? std::numeric_limits::lowest() : ov::float16(0.f); - } - } - } - - return mask_mem; - } - - void rotate_block(std::vector& cache_data, std::vector rotation_deltas, std::vector rotation_trig_lut_mem, @@ -932,9 +944,6 @@ struct xAttentionTest : public ::testing::TestWithParam { output_scores_mem = outputs.at("output_scores").get_memory(); } auto ref_data = xAttentionReference(pam).get_reference(); - // for (size_t i = 0; i < ref_data.first.size(); i++) { - // std::cout << i << "reference = " << ref_data.first[i] << std::endl; - // } compare(output_data_mem, output_scores_mem, ref_data); } @@ -942,27 +951,25 @@ struct xAttentionTest : public ::testing::TestWithParam { if (data_output_mem) { ASSERT_EQ(data_output_mem->count(), ref_data.first.size()); mem_lock mem_ptr(data_output_mem, get_test_stream()); + int mismatch_count = 0; for (size_t i = 0; i < data_output_mem->count(); i++) { - std::cout << i << ": result = " << mem_ptr[i] << ", reference = " << ref_data.first[i] << std::endl; - } - std::cout << "data_output_mem->count(): " << data_output_mem->count() << std::endl; - int num = 0; - for (size_t i = 0; i < data_output_mem->count(); i++) { - if (abs(mem_ptr[i] - ref_data.first[i]) > tolerance) { - // std::cout << "mem_ptr: " << mem_ptr[i] << " " << "ref_data: " << ref_data.first[i] << std::endl; - num++; + if (std::fabs(static_cast(mem_ptr[i]) - static_cast(ref_data.first[i])) > tolerance) { + mismatch_count++; } - // ASSERT_NEAR(mem_ptr[i], ref_data.first[i], tolerance) << " at index=" << i; } - std::cout << "num: " << num << std::endl; + EXPECT_LE(mismatch_count, int(data_output_mem->count() * 0.02)); } if (scores_output_mem) { ASSERT_EQ(scores_output_mem->count(), ref_data.second.size()); mem_lock mem_ptr(scores_output_mem, get_test_stream()); + int mismatch_count = 0; for (size_t i = 0; i < scores_output_mem->count(); i++) { - ASSERT_NEAR(mem_ptr[i], ref_data.second[i], tolerance) << " at index=" << i; + if (std::fabs(static_cast(mem_ptr[i]) - static_cast(ref_data.second[i])) > tolerance) { + mismatch_count++; + } } + EXPECT_LE(mismatch_count, int(scores_output_mem->count() * 0.02)); } } }; @@ -1002,28 +1009,17 @@ const auto DYNAMIC_INPUT_PAD = true; const auto ENABLE_FA_V2 = false; const auto DISABLE_FA_V2 = true; -INSTANTIATE_TEST_SUITE_P(smoke_xattention, +INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, xattention_test, ::testing::ValuesIn(std::vector{ #if ENABLE_PA_CM_PATH /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, token_size>=32, disable_mix_mode */ - // xattention_test_params{ {{32, 0}}, 2, 2, 2, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - xattention_test_params{ {{4096, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - // xattention_test_params{ {{1024, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - -// xattention_test_params{ {{1024, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, -// DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - -// xattention_test_params{ {{1, 31}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, -// DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, 32}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, -// ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, -// 1023}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, -// DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, 127}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, -// STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, 129}}, 2, 64, 64, 256, 0, -// DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token -// xattention_test_params{ {{1, 32}}, 28, 128, 128, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, -// DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{32, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + xattention_test_params{ {{4096, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + + xattention_test_params{ {{1, 31}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 32}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token #endif })); From 147063f22521d4f5353a69722a9fc9b300a53bf6 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 14 Oct 2025 22:20:44 +0800 Subject: [PATCH 61/96] Clean code --- .../include/openvino/reference/x.bakup.cpp | 528 ------------------ 1 file changed, 528 deletions(-) delete mode 100644 src/core/reference/include/openvino/reference/x.bakup.cpp diff --git a/src/core/reference/include/openvino/reference/x.bakup.cpp b/src/core/reference/include/openvino/reference/x.bakup.cpp deleted file mode 100644 index 9e69cfabcf1816..00000000000000 --- a/src/core/reference/include/openvino/reference/x.bakup.cpp +++ /dev/null @@ -1,528 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "openvino/reference/divide.hpp" -#include "openvino/reference/matmul.hpp" -#include "openvino/reference/softmax.hpp" -#include "openvino/reference/transpose.hpp" -#include "openvino/runtime/tensor.hpp" - -namespace ov::reference { - -using XAttentionBlockIndex = - std::pair; // .first is the *query* dimension block index, .second is *key* -using XAttentionRetainedBlockIndices = std::set; -using XAttentionRetainedBlockIndicesForAllHeads = std::vector; - -/** @brief Reference implementation of the XAttention sparse attention prefill mechanism - * (https://arxiv.org/abs/2503.16428) */ -template -class XAttentionBlockSelector { -public: - /** @param threshold Defines a threshold for introduced block sparsity - XAttention attempts to preserve the - * smallest subset of attention score matrix blocks so that the ratio of the attention score sum to the total sum of - * attention score matrix elements is no less than `threshold`. In other words, `threshold` defines a fraction of - * the attention score mass which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0 - * corresponding to 0% of the blocks retained, and 1.0 corresponding to 100% of the blocks retained. - * @param block_size The size of blocks into which the attention score matrix [num_heads, query_token_dimension, - * key_token_dimension] will be subdivided for purposes of determining the subset of the most important blocks - * according to `threshold`. This subdivision occurs on query and key dimensions of the attention score matrix with - * the same granularity, i.e. the resulting blocks have equal size on both dimensions. Essentially `block_size` - * defines the granularity of the eventual sparse attention computations. Must be a multiple of `stride`. - * @param stride The stride at which the full attention matrix is subsampled in a block-antidiagonal fashion to - * estimate the block importance. Note that the full attention matrix is not computed, instead the original query - * and key matrices are reshaped appropriately so that only the necessary elements are computed. Ideally, the - * computational complexity of the entire block estimation operation is `stride` times lower than the full attention - * matrix computation. - * */ - XAttentionBlockSelector(double threshold, size_t block_size, size_t stride) - : m_threshold(threshold), - m_block_size(block_size), - m_stride(stride) { - OPENVINO_ASSERT(m_block_size % m_stride == 0); - } - - /** Assuming the input tensor is either a query tensor or key tensor, reshapes it in a diagonal or antidiagonal - * fashion as appropriate so that the resulting matrices could be used to compute the block-antidiagonal subset of - * the attention matrix in further operations. For the query tensor, the antidiagonal reshaping should be applied, - * and diagonal - for the key tensor. Note that for the diagonal reshaping the data layout is effectively unchanged - * and only the shape can be adjusted in the efficient implementation of the same operation in HW. - * @param input_data Pointer to the input tensor data (query or key) - * @param input_shape Shape of the input tensor data (query or key). Expected shape is [num_heads, num_tokens, - * head_size], where `num_tokens` must be a multiple of `stride`. - * @param output_data Pointer to the output tensor data (reshaped query or key storage) - * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_tokens / stride, head_size * - * stride] - * @param is_antidiagonal Whether to reshape antidiagonally (true) or diagonally (false). Use `true` for query - * tensor and `false` for key tensor. - */ - void diagonal_reshape(const T* input_data, - const Shape& input_shape, - T* output_data, - const Shape& out_shape, - bool is_antidiagonal) { - OPENVINO_ASSERT(input_shape.size() == 3); // [num_heads, num_tokens, head_size] - OPENVINO_ASSERT(out_shape.size() == 3); - OPENVINO_ASSERT(input_shape[0] == out_shape[0]); - OPENVINO_ASSERT(input_shape[1] % m_stride == 0); - OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]); - OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]); - - size_t num_stride_steps = input_shape[1] / m_stride; - for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) { - size_t head_offset = head_idx * input_shape[1] * input_shape[2]; - for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) { - for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) { - size_t input_offset = head_offset; - size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2]; - if (is_antidiagonal) { - input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2]; - } else { - input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2]; - } - std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T)); - } - } - } - } - - /** Performs a matrix multiplication on the input tensors Q and K and scales the result in a typical attention op - * fashion, i.e. Q @ K^T / (sqrt(D) * S). Additionally rescales by the stride value, as compared to the regular - * attention. - * @param reshaped_query_data Pointer to the reshaped query input. - * @param reshaped_key_data Pointer to the reshaped key input. - * @param reshaped_query_shape Shape of the reshaped query input data. Expected shape is [num_heads, - * num_query_tokens / stride, head_size * stride]. - * @param reshaped_key_shape Shape of the reshaped key input data. Expected shape is [num_heads, num_key_tokens / - * stride, head_size * stride]. - * @param out Pointer to the output tensor data (attention logit scores) - * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / stride, - * num_key_tokens / stride] - */ - void transpose_matmul_scale(const T* reshaped_query_data, - const T* reshaped_key_data, - const Shape& reshaped_query_shape, - const Shape& reshaped_key_shape, - T* out, - const Shape& out_shape) { - OPENVINO_ASSERT(reshaped_key_shape.size() == 3); - OPENVINO_ASSERT(reshaped_query_shape.size() == 3); - OPENVINO_ASSERT(reshaped_query_shape[0] == reshaped_key_shape[0]); - OPENVINO_ASSERT(reshaped_query_shape[2] == reshaped_key_shape[2]); - - OPENVINO_ASSERT(out_shape.size() == 3); - OPENVINO_ASSERT(out_shape[0] == reshaped_query_shape[0]); - OPENVINO_ASSERT(out_shape[1] == reshaped_query_shape[1]); - OPENVINO_ASSERT(out_shape[2] == reshaped_key_shape[1]); - - ov::reference::matmul(reshaped_query_data, - reshaped_key_data, - out, - reshaped_query_shape, - reshaped_key_shape, - out_shape, - /* transpose_arg0 = */ false, - /* transpose_arg1 = */ true); - - size_t out_size = out_shape[0] * out_shape[1] * out_shape[2]; - - for (size_t i = 0; i < out_size; i++) { - // The D in the formula above refers to the original head dimension, while - // reshaped_query_shape[2] had been scaled in the process of reshaping, therefore - // the formula is also adjusted: - out[i] = out[i] / std::sqrt(reshaped_query_shape[2] * m_stride); - } - } - - /** Performs a softmax operation on the last dimension of the rank-3 input tensor. - * @param reshaped_qk_product_data Pointer to the reshaped query-key product input (attention logits pre-softmax). - * @param reshaped_qk_product_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens / - * stride, num_key_tokens / stride]. - * @param out Pointer to the output tensor data (attention scores) - * @param out_shape Shape of the output tensor data. Expected shape is strictly equal to - * `reshaped_qk_product_shape`. - */ - void softmax(const T* reshaped_qk_product_data, - const Shape& reshaped_qk_product_shape, - T* out, - const Shape& out_shape) { - OPENVINO_ASSERT(reshaped_qk_product_shape.size() == 3); - OPENVINO_ASSERT(reshaped_qk_product_shape == out_shape); - ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2}); - } - - /** Divides the input rank-3 tensor into blocks along last two dimensions, performs the addition of the values - * inside each block and outputs each block sum into corresponding positions in the output tensor downsampled along - * the same dimensions. The output tensor dimensions are such that the query and key token dimensions are - * downsampled by `block_size` when compared to the *original* query and key tensors. - * @param attention_scores_data Pointer to the attention score input. - * @param attention_score_shape Shape of the attention score input tensor. Expected shape is [num_heads, - * num_query_tokens / stride, num_key_tokens / stride], where `num_query_tokens` and `num_key_tokens` must be - * multiples of `block_size`. - * @param out Pointer to the output tensor data (block sums) - * @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / block_size, - * num_key_tokens / block_size]. - */ - void block_sum_attention_scores(const T* attention_scores_data, - const Shape& attention_scores_shape, - T* out, - const Shape& out_shape) { - OPENVINO_ASSERT(attention_scores_shape.size() == 3); // [num_heads, query_antidiagonals, key_antidiagonals] - size_t antidiagonals_per_xattention_block = m_block_size / m_stride; - OPENVINO_ASSERT(attention_scores_shape[1] % antidiagonals_per_xattention_block == 0); - OPENVINO_ASSERT(attention_scores_shape[2] % antidiagonals_per_xattention_block == 0); - - OPENVINO_ASSERT(out_shape[0] == attention_scores_shape[0]); - OPENVINO_ASSERT(out_shape[1] == - attention_scores_shape[1] / antidiagonals_per_xattention_block); // query length, blocked - OPENVINO_ASSERT(out_shape[2] == - attention_scores_shape[2] / antidiagonals_per_xattention_block); // key length, blocked - - std::memset(out, 0, out_shape[0] * out_shape[1] * out_shape[2] * sizeof(T)); - - for (size_t head_idx = 0; head_idx < attention_scores_shape[0]; head_idx++) { - size_t in_head_offset = head_idx * attention_scores_shape[1] * attention_scores_shape[2]; - size_t out_head_offset = head_idx * out_shape[1] * out_shape[2]; - for (size_t query_len_idx = 0; query_len_idx < attention_scores_shape[1]; query_len_idx++) { - for (size_t key_len_idx = 0; key_len_idx < attention_scores_shape[2]; key_len_idx++) { - size_t query_block_idx = query_len_idx / antidiagonals_per_xattention_block; - size_t key_block_idx = key_len_idx / antidiagonals_per_xattention_block; - auto target_block_sum_ptr = out + out_head_offset + query_block_idx * out_shape[2] + key_block_idx; - *target_block_sum_ptr += *(attention_scores_data + in_head_offset + - query_len_idx * attention_scores_shape[2] + key_len_idx); - } - } - } - } - - /** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension, - * so that the elements constitute a smallest subset constituting a sum portion no less than `threshold` of the - * total element sum. - * @param blocked_scores_data Pointer to the blocked score input. - * @param blocked_attention_scores_shape Shape of the blocked score input tensor. Expected shape is [num_heads, - * num_query_tokens / block_size, num_key_tokens / block_size] - * @return A vector of size `num_heads` of sets, each set containing pairs of block indices (.first is the block - * index along the query dimension, .second - along the key). Each set is the head-specific subset of blocks - * corresponding to the property described above. - */ -// template -void print_blocked_attention_scores(const T* data, - size_t num_heads, - size_t num_q_blocks, - size_t num_k_blocks) { - std::cout << "blocked_attention_scores shape: [" - << num_heads << ", " << num_q_blocks << ", " << num_k_blocks << "]\n"; - - for (size_t h = 0; h < num_heads; ++h) { - std::cout << "Head " << h << ":\n"; - std::cout << std::setw(8) << ""; - for (size_t k = 0; k < num_k_blocks; ++k) { - std::cout << std::setw(12) << ("K" + std::to_string(k)); - } - std::cout << "\n"; - - for (size_t q = 0; q < num_q_blocks; ++q) { - std::cout << std::setw(6) << ("Q" + std::to_string(q)) << " "; - double row_sum = 0.0; - for (size_t k = 0; k < num_k_blocks; ++k) { - size_t idx = h * (num_q_blocks * num_k_blocks) + q * num_k_blocks + k; - double v = static_cast(static_cast(*(data + idx))); - row_sum += v; - std::cout << std::setw(12) << std::fixed << std::setprecision(6) << v; - } - std::cout << " sum=" << std::fixed << std::setprecision(6) << row_sum << "\n"; - } - std::cout << std::flush; - } -} -// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( -// const T* blocked_attention_scores_data, -// const Shape& blocked_attention_scores_shape) { -// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); -// // [num_heads, num_blocks_in_query, num_blocks_in_key] - -// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - -// struct IndexAndScore { -// XAttentionBlockIndex idx; -// T score; -// }; - -// const size_t num_heads = blocked_attention_scores_shape[0]; -// const size_t num_q_blocks = blocked_attention_scores_shape[1]; -// const size_t num_k_blocks = blocked_attention_scores_shape[2]; -// // print_blocked_attention_scores(blocked_attention_scores_data, num_heads, num_q_blocks, num_k_blocks); - -// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { -// size_t head_offset = head_idx * num_q_blocks * num_k_blocks; - -// for (size_t q_block_idx = 0; q_block_idx < num_q_blocks; q_block_idx++) { -// std::vector indices_and_scores; -// indices_and_scores.reserve(num_k_blocks); - -// double total_sum = 0.0; - -// for (size_t k_block_idx = 0; k_block_idx < num_k_blocks; k_block_idx++) { -// size_t target_offset = head_offset + q_block_idx * num_k_blocks + k_block_idx; -// T current_score = *(blocked_attention_scores_data + target_offset); -// indices_and_scores.push_back({{q_block_idx, k_block_idx}, current_score}); -// total_sum += current_score; -// } - -// double required_sum = m_threshold * total_sum; - -// std::sort(indices_and_scores.begin(), indices_and_scores.end(), -// [](const IndexAndScore& a, const IndexAndScore& b) { -// return a.score > b.score; -// }); - -// std::vector shifted_cumsum(num_k_blocks, 0.0); -// for (size_t i = 1; i < num_k_blocks; i++) { -// shifted_cumsum[i] = shifted_cumsum[i - 1] + indices_and_scores[i - 1].score; -// } - -// for (size_t i = 0; i < num_k_blocks; i++) { -// if (shifted_cumsum[i] < required_sum) { -// retval[head_idx].insert(indices_and_scores[i].idx); -// } -// } -// } -// } - -// return retval; -// } - - - -void dump_blocked_attention_scores_bin(const std::string& filename, - const float* data, - size_t num_heads, - size_t num_q_blocks, - size_t num_k_blocks) { - size_t total_elems = num_heads * num_q_blocks * num_k_blocks; - std::ofstream ofs(filename, std::ios::binary); - if (!ofs) { - std::cerr << "Failed to open file for writing: " << filename << std::endl; - return; - } - ofs.write(reinterpret_cast(data), total_elems * sizeof(float)); - ofs.close(); - - std::cout << "✅ Dumped blocked_attention_scores to: " << filename - << " (" << total_elems << " elements, " - << sizeof(float) * total_elems / 1024.0 << " KB)\n"; -} - -// template -XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( - const T* blocked_attention_scores_data, - const Shape& blocked_attention_scores_shape) { - OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); - // [num_heads, num_blocks_in_query, num_blocks_in_key] - - auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - - struct IndexAndScore { - XAttentionBlockIndex idx; - T score; - }; - - const size_t num_heads = blocked_attention_scores_shape[0]; - const size_t num_q_blocks = blocked_attention_scores_shape[1]; - const size_t num_k_blocks = blocked_attention_scores_shape[2]; - print_blocked_attention_scores(blocked_attention_scores_data, num_heads, num_q_blocks, num_k_blocks); - - size_t total_elems = num_heads * num_q_blocks * num_k_blocks; - std::vector data_f32(total_elems); - for (size_t i = 0; i < total_elems; i++) - data_f32[i] = static_cast(blocked_attention_scores_data[i]); - dump_blocked_attention_scores_bin("blocked_attention_scores.bin", - data_f32.data(), num_heads, num_q_blocks, num_k_blocks); - for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { - size_t head_offset = head_idx * num_q_blocks * num_k_blocks; - - for (size_t q_block_idx = 0; q_block_idx < num_q_blocks; q_block_idx++) { - std::vector indices_and_scores; - indices_and_scores.reserve(num_k_blocks); - - double total_sum = 0.0; - for (size_t k_block_idx = 0; k_block_idx < num_k_blocks; k_block_idx++) { - size_t offset = head_offset + q_block_idx * num_k_blocks + k_block_idx; - T score = *(blocked_attention_scores_data + offset); - indices_and_scores.push_back({{q_block_idx, k_block_idx}, score}); - total_sum += score; - } - - double required_sum = m_threshold * total_sum; - - // === 与 Python 一致:按 score 降序排序 === - std::sort(indices_and_scores.begin(), indices_and_scores.end(), - [](const IndexAndScore& a, const IndexAndScore& b) { - return a.score > b.score; - }); - - // === 模拟 Python 的 cumulative_sum_without_self === - // 即:每个元素的累积和是“之前所有元素的和”,自身不计入。 - std::vector shifted_cumsum(num_k_blocks, 0.0); - for (size_t i = 1; i < num_k_blocks; i++) { - shifted_cumsum[i] = shifted_cumsum[i - 1] + indices_and_scores[i - 1].score; - } - - // === 选择 cumulative_sum_without_self < required_sum 的 block === - for (size_t i = 0; i < num_k_blocks; i++) { - if (shifted_cumsum[i] < required_sum) { - retval[head_idx].insert(indices_and_scores[i].idx); - } - } - - // ✅ Python 中通常也会强制保留“自身 block”,即 (q_block_idx, q_block_idx) - retval[head_idx].insert({q_block_idx, q_block_idx}); - } - } - - return retval; -} - - - // XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, - // const Shape& blocked_attention_scores_shape) { - // OPENVINO_ASSERT(blocked_attention_scores_shape.size() == - // 3); // [num_heads, num_blocks_in_query, num_blocks_in_key] - - // auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - - // struct IndexAndScore { - // XAttentionBlockIndex idx; - // T score; - // bool operator<(const IndexAndScore& rhs) const { - // return score < rhs.score; - // } - // }; - - // for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) { - // size_t head_offset = head_idx * blocked_attention_scores_shape[1] * blocked_attention_scores_shape[2]; - // std::priority_queue indices_and_scores_queue; - // double total_sum = 0.0; - // for (size_t q_block_idx = 0; q_block_idx < blocked_attention_scores_shape[1]; q_block_idx++) { - - // for (size_t k_block_idx = 0; k_block_idx < blocked_attention_scores_shape[2]; k_block_idx++) { - // size_t target_offset = head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx; - // T current_score = *(blocked_attention_scores_data + target_offset); - // indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score}); - // total_sum += current_score; - // } - // } - // double cumsum = 0.0; - // double required_sum = m_threshold * total_sum; - // while (cumsum < required_sum && !indices_and_scores_queue.empty()) { - // auto index_and_largest_score = indices_and_scores_queue.top(); - // indices_and_scores_queue.pop(); - // cumsum += index_and_largest_score.score; - // retval[head_idx].insert(index_and_largest_score.idx); - // } - // } - // return retval; - // } - - /** Applies XAttention to the provided query and key matrices, returning the subset of the most important blocks for - * each attention head, according to the configured block size and threshold, which are to be preserved in the - * subsequent sparse attention computation. - * @param query_data Pointer to the query input tensor data - * @param query_shape Shape of the query input tensor data. Expected shape is [num_heads, num_query_tokens, - * head_size], where `num_query_tokens` must be a multiple of both `block_size` and `stride`, padded with zeroes if - * necessary to do so in the real-world scenario. - * @param key_data Pointer to the key input tensor data - * @param key_shape Shape of the key input tensor data. Expected shape is [num_heads, num_key_tokens, head_size], - * where `num_key_tokens` must be a multiple of both `block_size` and `stride`, padded with zeroes if necessary to - * do so in the real-world scenario. - * @return A vector of size `num_heads` of sets, each set containing pairs of block indices (.first is the block - * index along the query dimension, .second - along the key). Each set is the head-specific subset of blocks that - * must be preserved in the sparse attention computation. Indices are given in units of XAttention-specific - * `block_size` (as configured), which may differ from the block size in the paged attention implementation. - */ - XAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, - const Shape& query_shape, - const T* key_data, - const Shape& key_shape) { - OPENVINO_ASSERT(query_shape.size() == 3); // [num_heads, query_token_len, head_dim] - OPENVINO_ASSERT(key_shape.size() == 3); // [num_heads, key_token_len, head_dim] - - OPENVINO_ASSERT(key_shape[0] == query_shape[0]); - OPENVINO_ASSERT(key_shape[2] == query_shape[2]); - - OPENVINO_ASSERT(query_shape[1] % m_stride == 0); - OPENVINO_ASSERT(key_shape[1] % m_stride == 0); - - OPENVINO_ASSERT(query_shape[1] % m_block_size == 0); - OPENVINO_ASSERT(key_shape[1] % m_block_size == 0); - - Shape reshaped_query_shape = {query_shape[0], query_shape[1] / m_stride, query_shape[2] * m_stride}; - auto q_buf = allocate_buf(reshaped_query_shape); - diagonal_reshape(query_data, query_shape, q_buf.get(), reshaped_query_shape, /* is_antidiagonal = */ true); - - Shape reshaped_key_shape = {key_shape[0], key_shape[1] / m_stride, key_shape[2] * m_stride}; - auto k_buf = allocate_buf(reshaped_key_shape); - diagonal_reshape(key_data, key_shape, k_buf.get(), reshaped_key_shape, /* is_antidiagonal = */ false); - - Shape transpose_matmul_scaled_shape = {key_shape[0], query_shape[1] / m_stride, key_shape[1] / m_stride}; - auto qk_buf = allocate_buf(transpose_matmul_scaled_shape); - transpose_matmul_scale(q_buf.get(), - k_buf.get(), - reshaped_query_shape, - reshaped_key_shape, - qk_buf.get(), - transpose_matmul_scaled_shape); - q_buf.reset(); - k_buf.reset(); - - Shape attention_scores_shape = transpose_matmul_scaled_shape; - auto attn_score_buf = allocate_buf(attention_scores_shape); - softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape); - qk_buf.reset(); - - size_t antidiagonals_per_xattention_block = m_block_size / m_stride; - Shape block_sum_shape = {attention_scores_shape[0], - attention_scores_shape[1] / antidiagonals_per_xattention_block, - attention_scores_shape[2] / antidiagonals_per_xattention_block}; - auto block_sum_buf = allocate_buf(block_sum_shape); - block_sum_attention_scores(attn_score_buf.get(), attention_scores_shape, block_sum_buf.get(), block_sum_shape); - attn_score_buf.reset(); - - auto selected_block_indices = get_block_indices_to_keep(block_sum_buf.get(), block_sum_shape); - block_sum_buf.reset(); - - return selected_block_indices; - } - - /** - * @param shape Shape of a tensor - * @return A shared_ptr owning a buffer that can be used to store tensor data for the given shape. - * */ - std::shared_ptr allocate_buf(const Shape& shape) { - return std::shared_ptr(new T[ov::shape_size(shape)]); - } - - /** - * @param token_length An integer value - * @return The closest multiple of `block_size` to `token_length`, rounding up. - * */ - size_t pad_to_block(size_t token_length) { - return (token_length + m_block_size - 1) / m_block_size * m_block_size; - } - - double m_threshold; - size_t m_block_size; - size_t m_stride; -}; - -} // namespace ov::reference \ No newline at end of file From c02fb34f00775e2648042ef84d9160d5a27d40c7 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 14 Oct 2025 22:21:35 +0800 Subject: [PATCH 62/96] Clean code --- .../src/graph/impls/cm/paged_attention.cpp | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 32a198dbdd2957..b246637d3e7c9e 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -104,28 +104,28 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { // stream.finish(); // std::cout << "finish xattn_estimate_gemmqk!\n"; res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; -// #if DUMP_XATTN_BLOCK_MASK - // { - // cldnn::stream& stream = instance.get_network().get_stream(); - // stream.finish(); - // static uint32_t pa_id = 0; - // std::cout << "finish xattn_estimate_find_block!\n"; - // auto output_mem = instance.get_intermediates_memories()[4]; - // mem_lock lock(output_mem, stream); - // auto& layout = output_mem->get_layout(); - // std::string data_type = ov::element::Type(layout.data_type).get_type_name(); - // std::string format = layout.format.to_string(); - // std::string tensor; - // auto dims = layout.get_dims(); - // for (size_t r = 0 ; r < layout.get_rank() ; r++) { - // tensor += ("_" + to_string(dims[r])); - // } - // // std::string filename = "PA" + std::to_string(pa_id) + "__" + data_type + "_" + tensor + "__" + format + ".bin"; - // std::string filename = "PA" + std::to_string(pa_id) + ".bin"; - // ov::util::save_binary(filename, lock.data(), output_mem->size()); - // pa_id++; - // } -// #endif +#if DUMP_XATTN_BLOCK_MASK + { + cldnn::stream& stream = instance.get_network().get_stream(); + stream.finish(); + static uint32_t pa_id = 0; + std::cout << "finish xattn_estimate_find_block!\n"; + auto output_mem = instance.get_intermediates_memories()[4]; + mem_lock lock(output_mem, stream); + auto& layout = output_mem->get_layout(); + std::string data_type = ov::element::Type(layout.data_type).get_type_name(); + std::string format = layout.format.to_string(); + std::string tensor; + auto dims = layout.get_dims(); + for (size_t r = 0 ; r < layout.get_rank() ; r++) { + tensor += ("_" + to_string(dims[r])); + } + // std::string filename = "PA" + std::to_string(pa_id) + "__" + data_type + "_" + tensor + "__" + format + ".bin"; + std::string filename = "PA" + std::to_string(pa_id) + ".bin"; + ov::util::save_binary(filename, lock.data(), output_mem->size()); + pa_id++; + } +#endif res_event = {execute_stage(res_event, instance, xattn_estimate_post_proc)}; } res_event = {execute_stage(res_event, instance, pa_multi_token)}; From 314bd717658a102fb0b8e62b8a8f14933c684376 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 14 Oct 2025 22:24:23 +0800 Subject: [PATCH 63/96] Add CMXAttentionBlockSelector --- .../include/openvino/reference/xattention.hpp | 1998 ----------------- src/core/tests/reference/xattention.cpp | 550 ----- .../unit/test_cases/xattention_gpu_test.cpp | 1 - 3 files changed, 2549 deletions(-) delete mode 100644 src/core/reference/include/openvino/reference/xattention.hpp delete mode 100644 src/core/tests/reference/xattention.cpp diff --git a/src/core/reference/include/openvino/reference/xattention.hpp b/src/core/reference/include/openvino/reference/xattention.hpp deleted file mode 100644 index 0bb181e2460e23..00000000000000 --- a/src/core/reference/include/openvino/reference/xattention.hpp +++ /dev/null @@ -1,1998 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "openvino/reference/divide.hpp" -#include "openvino/reference/matmul.hpp" -#include "openvino/reference/softmax.hpp" -#include "openvino/reference/transpose.hpp" -#include "openvino/runtime/tensor.hpp" - -namespace ov::reference { - -using Shape = std::vector; - -using XAttentionBlockIndex = - std::pair; // .first is the *query* dimension block index, .second is *key* -using XAttentionRetainedBlockIndices = std::set; -using XAttentionRetainedBlockIndicesForAllHeads = std::vector; - -/** @brief Reference implementation of the XAttention sparse attention prefill mechanism - *[](https://arxiv.org/abs/2503.16428) */ -template -class XAttentionBlockSelector { -public: - XAttentionBlockSelector(double threshold, size_t block_size, size_t stride) - : m_threshold(threshold), - m_block_size(block_size), - m_stride(stride) { - OPENVINO_ASSERT(m_block_size % m_stride == 0); - } - - void diagonal_reshape(const T* input_data, - const Shape& input_shape, - T* output_data, - const Shape& out_shape, - bool is_antidiagonal) { - OPENVINO_ASSERT(input_shape.size() == 3); - OPENVINO_ASSERT(out_shape.size() == 3); - OPENVINO_ASSERT(input_shape[0] == out_shape[0]); - OPENVINO_ASSERT(input_shape[1] % m_stride == 0); - OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]); - OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]); - - size_t num_stride_steps = input_shape[1] / m_stride; - for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) { - size_t head_offset = head_idx * input_shape[1] * input_shape[2]; - for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) { - for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) { - size_t input_offset = head_offset; - size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2]; - if (is_antidiagonal) { - input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2]; - } else { - input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2]; - } - std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T)); - } - } - } - } - -void diagonal_reshape_kdb1_no_batch( - const T* input_data, // 原始 query buffer - const std::vector& input_shape, // [H, Q_orig, dim] - T* output_data, // 输出 q_buf - const std::vector& output_shape) -{ - - size_t H = input_shape[0]; - size_t Q_orig = input_shape[1]; - size_t dim = input_shape[2]; - size_t Q_new = output_shape[1]; - - - for (size_t h = 0; h < H; ++h) { - size_t head_in_offset = h * Q_orig * dim; - size_t head_out_offset = h * Q_new * m_stride * dim; - - for (size_t s = 0; s < m_stride; ++s) { - for (size_t q = 0; q < Q_new; ++q) { - size_t in_idx = head_in_offset + (m_stride - 1 - s + q * m_stride) * dim; - size_t out_idx = head_out_offset + q * m_stride * dim + s * dim; - std::memcpy(output_data + out_idx, input_data + in_idx, dim * sizeof(T)); - } - } - } -} - void diagonal_reshape_q(const T* input_data, - const Shape& input_shape, - T* output_data, - const Shape& out_shape, - bool is_antidiagonal) { - size_t B = 1; - size_t H = input_shape[0]; - int Q = input_shape[1]; - int dim = input_shape[2]; - for (size_t b = 0; b < B; ++b) { - for (size_t h = 0; h < H; ++h) { - size_t head_offset_in = b * H * Q * dim + h * Q * dim; - size_t head_offset_out = b * H * Q * dim * m_stride + h * Q * dim * m_stride; - for (size_t q = 0; q < Q / m_stride; ++q) { - for (size_t s = 0; s < m_stride; ++s) { - size_t in_idx = head_offset_in + (Q / m_stride) * s + q; // 交错取值 - size_t out_idx = head_offset_out + q * m_stride * dim + s * dim; // 拼接到最后维度 - std::memcpy(output_data + out_idx, input_data + in_idx * dim, dim * sizeof(T)); - } - } - } - } - } - - void transpose_matmul_scale(const T* reshaped_query_data, - const T* reshaped_key_data, - const Shape& reshaped_query_shape, - const Shape& reshaped_key_shape, - T* out, - const Shape& out_shape) { - OPENVINO_ASSERT(reshaped_key_shape.size() == 3); - OPENVINO_ASSERT(reshaped_query_shape.size() == 3); - OPENVINO_ASSERT(reshaped_query_shape[0] == reshaped_key_shape[0]); - OPENVINO_ASSERT(reshaped_query_shape[2] == reshaped_key_shape[2]); - - OPENVINO_ASSERT(out_shape.size() == 3); - OPENVINO_ASSERT(out_shape[0] == reshaped_query_shape[0]); - OPENVINO_ASSERT(out_shape[1] == reshaped_query_shape[1]); - OPENVINO_ASSERT(out_shape[2] == reshaped_key_shape[1]); - - ov::reference::matmul(reshaped_query_data, - reshaped_key_data, - out, - reshaped_query_shape, - reshaped_key_shape, - out_shape, - false, - true); - - size_t out_size = out_shape[0] * out_shape[1] * out_shape[2]; - - for (size_t i = 0; i < out_size; i++) { - out[i] = out[i] / std::sqrt(reshaped_query_shape[2] * m_stride); - } - } - -void softmax_ww(const T* reshaped_qk_product_data, - const Shape& reshaped_qk_product_shape, - T* out, - const Shape& out_shape) { - OPENVINO_ASSERT(reshaped_qk_product_shape.size() == 3); - OPENVINO_ASSERT(reshaped_qk_product_shape == out_shape); - - size_t num_heads = reshaped_qk_product_shape[0]; - size_t q_blocks = reshaped_qk_product_shape[1]; - size_t k_blocks = reshaped_qk_product_shape[2]; - - std::vector temp_in(q_blocks * k_blocks); - std::vector temp_out(q_blocks * k_blocks); - - for (size_t h = 0; h < num_heads; ++h) { - for (size_t q = 0; q < q_blocks; ++q) { - // 将输入从 half 转为 float - for (size_t k = 0; k < k_blocks; ++k) { - size_t idx = h * q_blocks * k_blocks + q * k_blocks + k; - temp_in[k] = static_cast(reshaped_qk_product_data[idx]); - } - - // 数值稳定 softmax: 先减去最大值 - float max_val = *std::max_element(temp_in.begin(), temp_in.end()); - float sum_exp = 0.f; - for (size_t k = 0; k < k_blocks; ++k) { - temp_out[k] = std::exp(temp_in[k] - max_val); - sum_exp += temp_out[k]; - } - - // 归一化 - float inv_sum = 1.f / (sum_exp + 1e-12f); - for (size_t k = 0; k < k_blocks; ++k) { - size_t idx = h * q_blocks * k_blocks + q * k_blocks + k; - out[idx] = static_cast(temp_out[k] * inv_sum); - } - } - } -} - -void softmax_fp32(const T* input, const Shape& shape, T* output, const Shape& out_shape) { - OPENVINO_ASSERT(shape.size() == 3); - size_t dim0 = shape[0], dim1 = shape[1], dim2 = shape[2]; - - std::vector temp(dim2); - for (size_t i = 0; i < dim0 * dim1; ++i) { - size_t offset = i * dim2; - - // 1. 转为 float32 - for (size_t j = 0; j < dim2; ++j) - temp[j] = static_cast(input[offset + j]); - - // 2. 稳定 softmax - float max_val = *std::max_element(temp.begin(), temp.end()); - float sum_exp = 0.f; - for (float& v : temp) { - v = std::exp(v - max_val); - sum_exp += v; - } - - // 3. 写回 - for (size_t j = 0; j < dim2; ++j) - output[offset + j] = static_cast(temp[j] / sum_exp); - } -} - - void softmax(const T* reshaped_qk_product_data, - const Shape& reshaped_qk_product_shape, - T* out, - const Shape& out_shape) { - OPENVINO_ASSERT(reshaped_qk_product_shape.size() == 3); - OPENVINO_ASSERT(reshaped_qk_product_shape == out_shape); - ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2}); - } - - void block_sum_attention_scores(const T* attention_scores_data, - const Shape& attention_scores_shape, - T* out, - const Shape& out_shape) { - OPENVINO_ASSERT(attention_scores_shape.size() == 3); - size_t antidiagonals_per_xattention_block = m_block_size / m_stride; - OPENVINO_ASSERT(attention_scores_shape[1] % antidiagonals_per_xattention_block == 0); - OPENVINO_ASSERT(attention_scores_shape[2] % antidiagonals_per_xattention_block == 0); - - OPENVINO_ASSERT(out_shape[0] == attention_scores_shape[0]); - OPENVINO_ASSERT(out_shape[1] == attention_scores_shape[1] / antidiagonals_per_xattention_block); - OPENVINO_ASSERT(out_shape[2] == attention_scores_shape[2] / antidiagonals_per_xattention_block); - - std::memset(out, 0, out_shape[0] * out_shape[1] * out_shape[2] * sizeof(T)); - - for (size_t head_idx = 0; head_idx < attention_scores_shape[0]; head_idx++) { - size_t in_head_offset = head_idx * attention_scores_shape[1] * attention_scores_shape[2]; - size_t out_head_offset = head_idx * out_shape[1] * out_shape[2]; - for (size_t query_len_idx = 0; query_len_idx < attention_scores_shape[1]; query_len_idx++) { - for (size_t key_len_idx = 0; key_len_idx < attention_scores_shape[2]; key_len_idx++) { - size_t query_block_idx = query_len_idx / antidiagonals_per_xattention_block; - size_t key_block_idx = key_len_idx / antidiagonals_per_xattention_block; - auto target_block_sum_ptr = out + out_head_offset + query_block_idx * out_shape[2] + key_block_idx; - *target_block_sum_ptr += *(attention_scores_data + in_head_offset + - query_len_idx * attention_scores_shape[2] + key_len_idx); - } - } - } - } - -// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( -// const std::vector& input_tensor, // flattened [batch, head, q_block_num, k_block_num] -// size_t batch_size, -// size_t num_heads, -// size_t q_block_num, -// size_t k_block_num, -// double threshold, -// size_t block_size, -// size_t stride, -// bool causal = true) { - -// XAttentionRetainedBlockIndicesForAllHeads retained_blocks(num_heads); - -// for (size_t b = 0; b < batch_size; ++b) { -// for (size_t h = 0; h < num_heads; ++h) { -// auto& retained = retained_blocks[h]; -// const size_t base_offset = ((b * num_heads + h) * q_block_num) * k_block_num; - -// for (size_t q_block_idx = 0; q_block_idx < q_block_num; ++q_block_idx) { -// size_t diagonal_k = q_block_idx; -// std::vector> others; - -// // 1. 收集当前 query block 对所有 key block 的分数 -// double row_sum = 0.0; -// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { -// double score = input_tensor[base_offset + q_block_idx * k_block_num + k_block_idx]; -// if (std::isnan(score) || std::isinf(score)) -// score = 0.0; -// row_sum += score; -// if (k_block_idx != 0 && k_block_idx != diagonal_k) { -// others.emplace_back(score, k_block_idx); -// } -// } - -// // Debug: 打印 row_sum 和 q_block_idx -// /* -// if (h == 0) -// std::cout << "[Debug] q=" << q_block_idx -// << " row_sum=" << row_sum << " others=" << others.size() << "\n"; -// */ - -// if (row_sum <= 0.0) -// continue; - -// // 2. 强制保留 (q, 0) 和 diagonal -// retained.insert({q_block_idx, 0}); -// retained.insert({q_block_idx, diagonal_k}); - -// // 3. 按分数降序排列 others -// std::sort(others.begin(), others.end(), -// [](const auto& a, const auto& b) { return a.first > b.first; }); - -// // 4. 计算累计阈值 -// double required_sum = threshold * row_sum; -// double cumsum = 0.0; - -// std::priority_queue pq; - -// // ✅ 修复点:原代码用了 others.size() - 2,导致丢项。应当 push 全部候选。 -// for (size_t i = 0; i < others.size(); ++i) { -// pq.push({others[i].second, others[i].first}); -// } - -// // Debug: 打印 top 若干项 -// /* -// if (h == 0 && (q_block_idx == 6 || q_block_idx == 7)) { -// std::cout << "[Debug] q=" << q_block_idx << " others(sorted): "; -// for (size_t i = 0; i < std::min(others.size(), 8); ++i) -// std::cout << "(" << others[i].second << "," << std::fixed << std::setprecision(3) -// << others[i].first << ") "; -// std::cout << "\n"; -// } -// */ - -// // 5. 从大到小取,直到累计到阈值 -// while (!pq.empty() && cumsum < required_sum) { -// auto top = pq.top(); -// pq.pop(); -// cumsum += top.score; -// retained.insert({q_block_idx, top.index}); -// } - -// // Debug: 打印累计结果 -// /* -// if (h == 0 && (q_block_idx == 6 || q_block_idx == 7)) { -// std::cout << "[Debug] q=" << q_block_idx -// << " required=" << required_sum -// << " cumsum=" << cumsum -// << " retained=" << retained.size() << "\n"; -// } -// */ - -// // 6. causal mask:只保留 k <= q -// if (causal) { -// std::set> causal_retained; -// for (auto& kv : retained) { -// if (kv.second <= kv.first) -// causal_retained.insert(kv); -// } -// retained = std::move(causal_retained); -// } -// } -// } -// } - -// return retained_blocks; -// } - - -// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, -// const Shape& blocked_attention_scores_shape) { -// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); - -// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - -// struct IndexAndScore { -// size_t k_block_idx; -// double score; -// bool operator<(const IndexAndScore& rhs) const { -// return score < rhs.score; -// } -// }; - -// size_t q_block_num = blocked_attention_scores_shape[1]; -// size_t k_block_num = blocked_attention_scores_shape[2]; -// size_t current_index = k_block_num - q_block_num; - -// for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) { -// auto& retained = retval[head_idx]; -// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { -// double row_sum = 0.0; -// for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { -// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// row_sum += static_cast(blocked_attention_scores_data[offset]); -// } - -// double required_sum = m_threshold * row_sum; -// double cumsum = 0.0; -// // Force include first -// size_t k_block_idx = 0; -// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// double score = static_cast(blocked_attention_scores_data[offset]); -// cumsum += score; -// retained.insert({q_block_idx, k_block_idx}); -// // Force include diagonal -// size_t diagonal_k = current_index + q_block_idx; -// offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + diagonal_k; -// score = static_cast(blocked_attention_scores_data[offset]); -// cumsum += score; -// retained.insert({q_block_idx, diagonal_k}); -// // Others - -// std::vector> others; -// for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { -// if (k_block_idx == 0 || k_block_idx == diagonal_k) -// continue; -// offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// double sc = static_cast(blocked_attention_scores_data[offset]); -// others.emplace_back(sc, k_block_idx); -// } - -// std::sort(others.begin(), others.end(), [](const auto& a, const auto& b) { -// return a.first > b.first; -// }); - -// std::priority_queue indices_and_scores_queue; - -// for (size_t i = 0; i < others.size() - 2; i++) { -// if (i >= others.size()) -// break; - -// indices_and_scores_queue.push({others[i].second, others[i].first}); -// } - -// while (cumsum < required_sum && !indices_and_scores_queue.empty()) { -// auto index_and_largest_score = indices_and_scores_queue.top(); - -// indices_and_scores_queue.pop(); - -// cumsum += index_and_largest_score.score; - -// retained.insert({q_block_idx, index_and_largest_score.k_block_idx}); -// } -// } - -// // Enforce causal - -// auto it = retained.begin(); - -// while (it != retained.end()) { -// size_t q = it->first; - -// size_t k = it->second; - -// if (k >= current_index && (k - current_index) > q) { -// it = retained.erase(it); - -// } else { -// ++it; -// } -// } -// } - -// return retval; -// } - -// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, -// const Shape& blocked_attention_scores_shape) { -// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); - -// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - -// size_t num_heads = blocked_attention_scores_shape[0]; -// size_t q_block_num = blocked_attention_scores_shape[1]; -// size_t k_block_num = blocked_attention_scores_shape[2]; - -// // keep the same current_index computation as original C++ (matches Python caller behavior) -// size_t current_index = k_block_num - q_block_num; - -// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { -// auto& retained = retval[head_idx]; - -// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { -// // --- 1) 读一行(q_block_idx)并计算 row_sum -// std::vector row(k_block_num); -// double row_sum = 0.0; -// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { -// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// double v = static_cast(blocked_attention_scores_data[offset]); -// if (std::isnan(v) || std::isinf(v)) -// v = 0.0; -// row[k_block_idx] = v; -// row_sum += v; -// } - -// double required_sum = m_threshold * row_sum; - -// // --- 2) 构造 forced mask(与 Python 中 mask 一致:k==0 与 diagonal_k) -// std::vector forced(k_block_num, 0); -// forced[0] = 1; -// size_t diagonal_k = current_index + q_block_idx; -// if (diagonal_k < k_block_num) -// forced[diagonal_k] = 1; - -// // --- 3) 计算 forced_sum(就是 torch.where(mask, input_tensor, 0).sum(...)) -// double forced_sum = 0.0; -// for (size_t k = 0; k < k_block_num; ++k) -// if (forced[k]) -// forced_sum += row[k]; - -// // --- 4) 构造 other_values = masked_fill(mask, 0) 并做降序排序(保留索引) -// std::vector> other_pairs; // (value, k_idx) -// other_pairs.reserve(k_block_num); -// for (size_t k = 0; k < k_block_num; ++k) { -// double val = forced[k] ? 0.0 : row[k]; -// other_pairs.emplace_back(val, k); -// } -// std::sort(other_pairs.begin(), other_pairs.end(), [](const auto& a, const auto& b) { -// return a.first > b.first; -// }); - -// // --- 5) 按 Python: 构造 sorted_values_final = [0, forced_sum, other_pairs[0..-3]] (即 sorted_values[:-2]) -// // 这样 final length == k_block_num(相同长度) -// std::vector sorted_values_cat; -// sorted_values_cat.reserve(k_block_num); -// sorted_values_cat.push_back(0.0); -// sorted_values_cat.push_back(forced_sum); -// size_t take = 0; -// if (k_block_num >= 2) { -// // other_pairs.size() == k_block_num -// // we need to append other_pairs[0 .. k_block_num-3] => count = k_block_num - 2 -// // but slice is other_pairs[:-2] -> indices [0 .. k_block_num-3] (count k_block_num-2) -// take = (k_block_num >= 2) ? (k_block_num - 2) : 0; -// } -// for (size_t i = 0; i < take; ++i) { -// sorted_values_cat.push_back(other_pairs[i].first); -// } -// // safety: if for some reason sizes mismatch, pad zeros to reach length k_block_num -// while (sorted_values_cat.size() < k_block_num) -// sorted_values_cat.push_back(0.0); - -// // --- 6) 构造 index_order == argsort(descending) of where(mask, BIG*(1+row), row) -// std::vector> index_pairs; -// index_pairs.reserve(k_block_num); -// const double BIG = 100000.0; // mirrors Python 100000*(1 + input_tensor) -// for (size_t k = 0; k < k_block_num; ++k) { -// double key = forced[k] ? (BIG * (1.0 + row[k])) : row[k]; -// index_pairs.emplace_back(key, k); -// } -// std::sort(index_pairs.begin(), index_pairs.end(), [](const auto& a, const auto& b) { -// return a.first > b.first; -// }); - -// // --- 7) 计算 cumulative_sum_without_self == cumsum( [0] + sorted_values_cat[0:-1] ) -// // 即 cumsum_before[pos] = sum(sorted_values_cat[0 .. pos-1]) -// std::vector cumsum_before(k_block_num, 0.0); -// double acc = 0.0; -// for (size_t pos = 0; pos < k_block_num; ++pos) { -// cumsum_before[pos] = acc; -// acc += sorted_values_cat[pos]; -// } - -// // --- 8) 构造 index 掩码: index[pos] = index_pairs[pos].second if cumsum_before[pos] < required_sum else 0 -// // 然后把 index[pos] 对应的 k 插入 retained(等价于 python 的 fancy assignment) -// // 先强制包含 (align with original C++) -// retained.insert({q_block_idx, 0}); -// if (diagonal_k < k_block_num) -// retained.insert({q_block_idx, diagonal_k}); - -// for (size_t pos = 0; pos < k_block_num; ++pos) { -// if (cumsum_before[pos] < required_sum) { -// size_t sel_k = index_pairs[pos].second; -// retained.insert({q_block_idx, sel_k}); -// } else { -// // python uses 0 where mask false; but we already inserted 0 above -// } -// } - -// // --- Note: we intentionally do NOT add any ad-hoc "neighbor extension" here. -// // The above faithfully reproduces Python's selection (including the "[:-2]" trimming). -// // Debug printing (commented): -// if (head_idx == 0 && (q_block_idx == 6 || q_block_idx == 7)) { -// std::cout << "[DBG] q=" << q_block_idx -// << " row_sum=" << row_sum -// << " required=" << required_sum -// << " forced_sum=" << forced_sum -// << " cumsum_before(last)=" << cumsum_before.back() -// << " retained_count=" << retained.size() << std::endl; -// std::cout << " index_order: "; -// for (size_t i = 0; i < index_pairs.size(); ++i) std::cout << index_pairs[i].second << " "; -// std::cout << std::endl; -// std::cout << " sorted_values_cat: "; -// for (size_t i = 0; i < sorted_values_cat.size(); ++i) std::cout << sorted_values_cat[i] << " "; -// std::cout << std::endl; -// } -// } // q_block loop - -// // --- Enforce causal (keep original style/condition) -// auto it = retained.begin(); -// while (it != retained.end()) { -// size_t q = it->first; -// size_t k = it->second; -// if (k >= current_index && (k - current_index) > q) { -// it = retained.erase(it); -// } else { -// ++it; -// } -// } -// } // head loop - -// return retval; -// } -// template -// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( -// const T* blocked_attention_scores_data, -// const Shape& blocked_attention_scores_shape) { - -// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); - -// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - -// size_t num_heads = blocked_attention_scores_shape[0]; -// size_t q_block_num = blocked_attention_scores_shape[1]; -// size_t k_block_num = blocked_attention_scores_shape[2]; - -// // 当前索引保持与原始 C++ 一致,匹配 Python caller -// size_t current_index = k_block_num - q_block_num; - -// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { -// auto& retained = retval[head_idx]; - -// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { -// // --- 1) 读一行(q_block_idx)并计算 row_sum -// std::vector row(k_block_num); -// double row_sum = 0.0; -// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { -// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// double v = static_cast(blocked_attention_scores_data[offset]); -// if (std::isnan(v) || std::isinf(v)) -// v = 0.0; -// row[k_block_idx] = v; -// row_sum += v; -// } - -// double required_sum = m_threshold * row_sum; - -// // --- 2) 构造 forced mask(k==0 与 diagonal_k) -// std::vector forced(k_block_num, 0); -// forced[0] = 1; -// size_t diagonal_k = current_index + q_block_idx; -// if (diagonal_k < k_block_num) -// forced[diagonal_k] = 1; - -// // --- 3) 计算 forced_sum -// double forced_sum = 0.0; -// for (size_t k = 0; k < k_block_num; ++k) -// if (forced[k]) -// forced_sum += row[k]; - -// // --- 4) 构造 other_values = masked_fill(mask,0) 并降序排序 -// std::vector> other_pairs; -// other_pairs.reserve(k_block_num); -// for (size_t k = 0; k < k_block_num; ++k) { -// double val = forced[k] ? 0.0 : row[k]; -// other_pairs.emplace_back(val, k); -// } -// std::sort(other_pairs.begin(), other_pairs.end(), [](const auto& a, const auto& b) { -// return a.first > b.first; -// }); - -// // --- 5) 构造 sorted_values_cat -// std::vector sorted_values_cat; -// sorted_values_cat.reserve(k_block_num); -// sorted_values_cat.push_back(0.0); -// sorted_values_cat.push_back(forced_sum); -// size_t take = (k_block_num >= 2) ? (k_block_num - 2) : 0; -// for (size_t i = 0; i < take; ++i) { -// sorted_values_cat.push_back(other_pairs[i].first); -// } -// while (sorted_values_cat.size() < k_block_num) -// sorted_values_cat.push_back(0.0); - -// // --- 6) 构造 index_order -// std::vector> index_pairs; -// index_pairs.reserve(k_block_num); -// const double BIG = 100000.0; -// for (size_t k = 0; k < k_block_num; ++k) { -// double key = forced[k] ? (BIG * (1.0 + row[k])) : row[k]; -// index_pairs.emplace_back(key, k); -// } -// std::sort(index_pairs.begin(), index_pairs.end(), [](const auto& a, const auto& b) { -// return a.first > b.first; -// }); - -// // --- 7) 构造 cumulative_sum_without_self -// std::vector cumsum_before(k_block_num, 0.0); -// double acc = 0.0; -// for (size_t pos = 0; pos < k_block_num; ++pos) { -// cumsum_before[pos] = acc; -// acc += sorted_values_cat[pos]; -// } - -// // // --- 8) 累加保留逻辑,严格对应 Python -// // retained.insert({q_block_idx, 0}); -// // if (diagonal_k < k_block_num) -// // retained.insert({q_block_idx, diagonal_k}); - -// // for (size_t pos = 0; pos < k_block_num; ++pos) { -// // if (cumsum_before[pos] < required_sum) { -// // size_t sel_k = index_pairs[pos].second; -// // retained.insert({q_block_idx, sel_k}); -// // } else { -// // break; // <-- 关键修改,停止累加,避免多保留 (7,6) -// // } -// // } - -// // --- 8) 累加保留逻辑,严格对应 Python -// retained.insert({q_block_idx, 0}); -// if (diagonal_k < k_block_num) -// retained.insert({q_block_idx, diagonal_k}); - -// for (size_t pos = 0; pos < k_block_num; ++pos) { -// size_t sel_k = index_pairs[pos].second; -// if (!forced[sel_k] && cumsum_before[pos] >= required_sum) { -// // Python 对应 torch.where(index_mask, index, 0) -// continue; // 不保留非强制位置 -// } -// retained.insert({q_block_idx, sel_k}); -// } - - - -// // --- debug 打印(可注释) -// /* -// if (head_idx == 0 && (q_block_idx == 6 || q_block_idx == 7)) { -// std::cout << "[DBG] q=" << q_block_idx -// << " row_sum=" << row_sum -// << " required=" << required_sum -// << " forced_sum=" << forced_sum -// << " cumsum_before(last)=" << cumsum_before.back() -// << " retained_count=" << retained.size() << std::endl; -// std::cout << " index_order: "; -// for (size_t i = 0; i < index_pairs.size(); ++i) std::cout << index_pairs[i].second << " "; -// std::cout << std::endl; -// std::cout << " sorted_values_cat: "; -// for (size_t i = 0; i < sorted_values_cat.size(); ++i) std::cout << sorted_values_cat[i] << " "; -// std::cout << std::endl; -// } -// */ -// } - -// // --- Enforce causal -// auto it = retained.begin(); -// while (it != retained.end()) { -// size_t q = it->first; -// size_t k = it->second; -// if (k >= current_index && (k - current_index) > q) { -// it = retained.erase(it); -// } else { -// ++it; -// } -// } -// } - -// return retval; -// } - -// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, -// const Shape& blocked_attention_scores_shape) { -// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); - -// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - -// size_t num_heads = blocked_attention_scores_shape[0]; -// size_t q_block_num = blocked_attention_scores_shape[1]; -// size_t k_block_num = blocked_attention_scores_shape[2]; - -// // 与 Python 对齐 -// size_t current_index = k_block_num - q_block_num; - -// for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { -// auto& retained = retval[head_idx]; - -// for (size_t q_block_idx = 0; q_block_idx < q_block_num; ++q_block_idx) { -// // --- 1) 读取一行 -// std::vector row(k_block_num); -// double row_sum = 0.0; -// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { -// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// double v = static_cast(blocked_attention_scores_data[offset]); -// if (std::isnan(v) || std::isinf(v)) v = 0.0; -// row[k_block_idx] = v; -// row_sum += v; -// } - -// double required_sum = m_threshold * row_sum; - -// // --- 2) 强制保留位置 -// std::vector forced(k_block_num, 0); -// forced[0] = 1; -// size_t diagonal_k = current_index + q_block_idx; -// if (diagonal_k < k_block_num) forced[diagonal_k] = 1; - -// double forced_sum = 0.0; -// for (size_t k = 0; k < k_block_num; ++k) -// if (forced[k]) forced_sum += row[k]; - -// // --- 3) 其他值排序 -// std::vector> other_pairs; // (value, k_idx) -// for (size_t k = 0; k < k_block_num; ++k) -// other_pairs.emplace_back(forced[k] ? 0.0 : row[k], k); -// std::sort(other_pairs.begin(), other_pairs.end(), [](const auto& a, const auto& b) { -// return a.first > b.first; -// }); - -// // --- 4) 构造 sorted_values_cat -// std::vector sorted_values_cat; -// sorted_values_cat.push_back(0.0); -// sorted_values_cat.push_back(forced_sum); -// size_t take = k_block_num >= 2 ? k_block_num - 2 : 0; -// for (size_t i = 0; i < take; ++i) sorted_values_cat.push_back(other_pairs[i].first); -// while (sorted_values_cat.size() < k_block_num) sorted_values_cat.push_back(0.0); - -// // --- 5) 构造 index_pairs (argsort desc) -// std::vector> index_pairs; -// const double BIG = 100000.0; -// for (size_t k = 0; k < k_block_num; ++k) -// index_pairs.emplace_back(forced[k] ? (BIG * (1.0 + row[k])) : row[k], k); -// std::sort(index_pairs.begin(), index_pairs.end(), [](const auto& a, const auto& b) { -// return a.first > b.first; -// }); - -// // --- 6) cumsum_before -// std::vector cumsum_before(k_block_num, 0.0); -// double acc = 0.0; -// for (size_t pos = 0; pos < k_block_num; ++pos) { -// cumsum_before[pos] = acc; -// acc += sorted_values_cat[pos]; -// } - -// // --- 7) 强制保留 -// retained.insert({q_block_idx, 0}); -// if (diagonal_k < k_block_num) retained.insert({q_block_idx, diagonal_k}); - -// // --- 8) 按 Python 逻辑选择 -// for (size_t pos = 0; pos < k_block_num; ++pos) { -// if (cumsum_before[pos] < required_sum) { -// size_t sel_k = index_pairs[pos].second; -// retained.insert({q_block_idx, sel_k}); -// } -// } - -// // --- 9) 完整 debug 打印 -// std::cout << "[DBG] head=" << head_idx << " q=" << q_block_idx -// << " row_sum=" << row_sum -// << " required=" << required_sum -// << " forced_sum=" << forced_sum -// << " cumsum_before(last)=" << cumsum_before.back() -// << " retained_count=" << retained.size() << std::endl; - -// std::cout << " row: "; -// for (auto v : row) std::cout << v << " "; -// std::cout << std::endl; - -// std::cout << " forced: "; -// for (auto f : forced) std::cout << (int)f << " "; -// std::cout << std::endl; - -// std::cout << " other_pairs: "; -// for (auto& p : other_pairs) std::cout << "(" << p.first << "," << p.second << ") "; -// std::cout << std::endl; - -// std::cout << " sorted_values_cat: "; -// for (auto v : sorted_values_cat) std::cout << v << " "; -// std::cout << std::endl; - -// std::cout << " index_pairs: "; -// for (auto& p : index_pairs) std::cout << "(" << p.first << "," << p.second << ") "; -// std::cout << std::endl; - -// std::cout << " cumsum_before: "; -// for (auto v : cumsum_before) std::cout << v << " "; -// std::cout << std::endl; - -// std::cout << " retained before causal: "; -// for (auto& p : retained) std::cout << "(" << p.first << "," << p.second << ") "; -// std::cout << std::endl; -// } // q_block loop - -// // --- 10) enforce causal -// auto it = retained.begin(); -// while (it != retained.end()) { -// size_t q = it->first; -// size_t k = it->second; -// if (k >= current_index && (k - current_index) > q) -// it = retained.erase(it); -// else -// ++it; -// } - -// // --- 11) 打印 causal 后 retained -// std::cout << "[DBG] head=" << head_idx << " retained after causal: "; -// for (auto& p : retained) std::cout << "(" << p.first << "," << p.second << ") "; -// std::cout << std::endl; -// } // head loop - -// return retval; -// } - - -// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(const T* blocked_attention_scores_data, -// const Shape& blocked_attention_scores_shape) { -// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); - -// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - -// size_t num_heads = blocked_attention_scores_shape[0]; -// size_t q_block_num = blocked_attention_scores_shape[1]; -// size_t k_block_num = blocked_attention_scores_shape[2]; - -// size_t current_index = k_block_num - q_block_num; // Python caller behavior - -// const double BIG = 100000.0; - -// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { -// auto& retained = retval[head_idx]; - -// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { -// // --- 1) row -// std::vector row(k_block_num); -// double row_sum = 0.0; -// for (size_t k_block_idx = 0; k_block_idx < k_block_num; ++k_block_idx) { -// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// double v = static_cast(blocked_attention_scores_data[offset]); -// if (std::isnan(v) || std::isinf(v)) v = 0.0; -// row[k_block_idx] = v; -// row_sum += v; -// } -// double required_sum = m_threshold * row_sum; - -// // --- 2) forced mask -// std::vector forced(k_block_num, 0); -// forced[0] = 1; -// size_t diagonal_k = current_index + q_block_idx; -// if (diagonal_k < k_block_num) -// forced[diagonal_k] = 1; - -// // --- 3) forced sum -// double forced_sum = 0.0; -// for (size_t k = 0; k < k_block_num; ++k) -// if (forced[k]) forced_sum += row[k]; - -// // --- 4) other values -// std::vector> other_pairs; // value, k -// for (size_t k = 0; k < k_block_num; ++k) { -// if (!forced[k]) other_pairs.emplace_back(row[k], k); -// } -// std::sort(other_pairs.begin(), other_pairs.end(), -// [](const auto& a, const auto& b) { return a.first > b.first; }); - -// // --- 5) sorted_values_cat -// std::vector sorted_values_cat; -// sorted_values_cat.push_back(0.0); -// sorted_values_cat.push_back(forced_sum); -// size_t take_count = (other_pairs.size() >= 2) ? other_pairs.size() - 2 : other_pairs.size(); -// for (size_t i = 0; i < take_count; ++i) sorted_values_cat.push_back(other_pairs[i].first); -// while (sorted_values_cat.size() < k_block_num) sorted_values_cat.push_back(0.0); - -// // --- 6) index pairs (argsort) -// std::vector> index_pairs; -// for (size_t k = 0; k < k_block_num; ++k) { -// double key = forced[k] ? BIG * (1.0 + row[k]) : row[k]; -// index_pairs.emplace_back(key, k); -// } -// std::sort(index_pairs.begin(), index_pairs.end(), -// [](const auto& a, const auto& b) { return a.first > b.first; }); - -// // --- 7) cumsum_before -// std::vector cumsum_before(k_block_num, 0.0); -// double acc = 0.0; -// for (size_t pos = 0; pos < k_block_num; ++pos) { -// cumsum_before[pos] = acc; -// acc += sorted_values_cat[pos]; -// } - -// // --- 8) insert into retained -// // force include 0 and diagonal -// retained.insert({q_block_idx, 0}); -// if (diagonal_k < k_block_num) retained.insert({q_block_idx, diagonal_k}); - -// for (size_t pos = 0; pos < k_block_num; ++pos) { -// if (cumsum_before[pos] < required_sum) { -// size_t sel_k = index_pairs[pos].second; -// retained.insert({q_block_idx, sel_k}); -// } -// } - -// // --- debug print -// std::cout << "[DBG] head=" << head_idx << " q=" << q_block_idx -// << " row_sum=" << row_sum -// << " required=" << required_sum -// << " forced_sum=" << forced_sum -// << " cumsum_before(last)=" << cumsum_before.back() -// << " retained_count=" << retained.size() << "\n"; -// std::cout << " row: "; -// for (auto v : row) std::cout << v << " "; -// std::cout << "\n forced: "; -// for (auto f : forced) std::cout << int(f) << " "; -// std::cout << "\n other_pairs: "; -// for (auto& p : other_pairs) std::cout << "(" << p.first << "," << p.second << ") "; -// std::cout << "\n sorted_values_cat: "; -// for (auto v : sorted_values_cat) std::cout << v << " "; -// std::cout << "\n index_pairs: "; -// for (auto& p : index_pairs) std::cout << "(" << p.first << "," << p.second << ") "; -// std::cout << "\n cumsum_before: "; -// for (auto v : cumsum_before) std::cout << v << " "; -// std::cout << "\n retained before causal: "; -// for (auto& x : retained) std::cout << "(" << x.first << "," << x.second << ") "; -// std::cout << "\n"; -// } - -// // --- 9) causal mask -// auto it = retained.begin(); -// while (it != retained.end()) { -// size_t q = it->first; -// size_t k = it->second; -// if (k >= current_index && (k - current_index) > q) { -// it = retained.erase(it); -// } else { -// ++it; -// } -// } - -// // --- debug retained after causal -// std::cout << "[DBG] head=" << head_idx << " retained after causal: "; -// for (auto& x : retained) std::cout << "(" << x.first << "," << x.second << ") "; -// std::cout << "\n"; -// } - -// return retval; -// } - -void print_blocked_attention_scores(const T* blocked_attention_scores_data, - size_t num_heads, - size_t q_block_num, - size_t k_block_num) { - std::cout << "=== blocked_attention_scores_data ===\n"; - for (size_t h = 0; h < num_heads; ++h) { - std::cout << "Head " << h << ":\n"; - for (size_t q = 0; q < q_block_num; ++q) { - std::cout << " q_block " << q << ": "; - for (size_t k = 0; k < k_block_num; ++k) { - size_t offset = h * q_block_num * k_block_num + q * k_block_num + k; - std::cout << std::fixed << std::setprecision(6) - << blocked_attention_scores_data[offset] << " "; - } - std::cout << "\n"; - } - std::cout << "\n"; - } -} - -void print_retained_blocks(const XAttentionRetainedBlockIndicesForAllHeads& retained_blocks) { - for (size_t head = 0; head < retained_blocks.size(); ++head) { - std::cout << "[Head " << head << "] retained blocks: "; - for (const auto& p : retained_blocks[head]) { - std::cout << "(" << p.first << "," << p.second << ") "; - } - std::cout << std::endl; - } -} - -void print_scores(const std::vector>& scores) { - std::cout << "[Scores] "; - for (const auto& p : scores) { - std::cout << "(" << p.first << ", " << p.second << ") "; - } - std::cout << std::endl; -} - - - -// XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( -// T* blocked_attention_scores_data, -// const Shape& blocked_attention_scores_shape) { -// OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3); - -// auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]); - -// size_t num_heads = blocked_attention_scores_shape[0]; -// size_t q_block_num = blocked_attention_scores_shape[1]; -// size_t k_block_num = blocked_attention_scores_shape[2]; - -// print_blocked_attention_scores(blocked_attention_scores_data, -// num_heads, q_block_num, k_block_num); - - -// float blocked_attention_scores_values[q_block_num * k_block_num] = { -// 2.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, -// 1.1399f, 0.8601f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, -// 0.5426f, 0.8147f, 0.6427f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, -// 0.4169f, 0.5852f, 0.6589f, 0.3390f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, -// 0.5131f, 0.4026f, 0.4603f, 0.3615f, 0.2625f, 0.0000f, 0.0000f, 0.0000f, -// 0.3882f, 0.3218f, 0.3278f, 0.3583f, 0.3449f, 0.2589f, 0.0000f, 0.0000f, -// 0.3030f, 0.3146f, 0.2382f, 0.3002f, 0.2992f, 0.3479f, 0.1969f, 0.0000f, -// 0.2431f, 0.3503f, 0.3054f, 0.2146f, 0.2261f, 0.2692f, 0.1847f, 0.2065f -// }; - -// // 分配可写的 ov::float16 buffer -// // ov::float16* blocked_attention_scores_data = new ov::float16[num_heads * q_block_num * k_block_num]; - -// // 逐元素赋值 -// for (int i = 0; i < 64; ++i) { -// blocked_attention_scores_data[i] = ov::float16(blocked_attention_scores_values[i]); -// } - -// print_blocked_attention_scores(blocked_attention_scores_data, -// num_heads, q_block_num, k_block_num); - -// // ✅ Python 中没有 current_index 偏移的逻辑 -// // 原逻辑引入 offset 导致 diagonal 错位 -// // 如果确实需要 offset,可通过参数控制,但这里保持与 Python 一致 -// // size_t current_index = 0; - -// for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { -// auto& retained = retval[head_idx]; - -// for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { -// std::cout << "**************************\n"; -// // 1️⃣ 累加整行分数 -// double row_sum = 0.0; -// for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { -// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// row_sum += static_cast(blocked_attention_scores_data[offset]); -// } - -// double required_sum = m_threshold * row_sum; -// std::cout << "required_sum: " << required_sum << std::endl; -// double cumsum = 0.0; - -// // // 2️⃣ 强制保留 diagonal 块 -// // size_t diagonal_k = q_block_idx; -// // size_t offset_diag = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + diagonal_k; -// // double diag_score = static_cast(blocked_attention_scores_data[offset_diag]); -// // std::cout << "diag_score: " << diag_score << std::endl; -// // cumsum += diag_score; -// // retained.insert({q_block_idx, diagonal_k}); - -// // print_retained_blocks(retval); - -// // // 3️⃣ 收集所有候选块 -// // std::vector> scores; -// // scores.reserve(k_block_num); -// // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { -// // if (k_block_idx == diagonal_k) -// // continue; -// // if (k_block_idx == 0) continue; -// // size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// // scores.emplace_back(static_cast(blocked_attention_scores_data[offset]), k_block_idx); -// // } - -// // print_scores(scores); - -// // // 4️⃣ 降序排序(高分优先) -// // std::sort(scores.begin(), scores.end(), -// // [](const auto& a, const auto& b) { return a.first > b.first; }); - -// // // 5️⃣ 从高到低选取直到累积超过阈值 -// // for (auto& [score, k_block_idx] : scores) { -// // if (cumsum >= required_sum) -// // break; -// // cumsum += score; -// // retained.insert({q_block_idx, k_block_idx}); -// // } - - -// // 2️⃣ 强制保留 diagonal 块 -// size_t diagonal_k = q_block_idx; -// size_t offset_diag = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + diagonal_k; -// double diag_score = static_cast(blocked_attention_scores_data[offset_diag]); -// cumsum += diag_score; -// retained.insert({q_block_idx, diagonal_k}); - -// // 2️⃣.1️⃣ 额外:强制保留首列块 (k=0),与 Python mask[:, :, :, 0] = 1 一致 -// if (k_block_num > 0 && q_block_idx != 0) { -// size_t offset_first = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + 0; -// double first_col_score = static_cast(blocked_attention_scores_data[offset_first]); -// cumsum += first_col_score; -// retained.insert({q_block_idx, 0}); -// } - -// // 3️⃣ 收集其他候选块(去掉 diagonal 和首列) -// std::vector> scores; -// scores.reserve(k_block_num); -// for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { -// if (k_block_idx == diagonal_k || k_block_idx == 0) -// continue; -// size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; -// scores.emplace_back(static_cast(blocked_attention_scores_data[offset]), k_block_idx); -// } - -// // 4️⃣ 降序排序(高分优先) -// std::sort(scores.begin(), scores.end(), -// [](const auto& a, const auto& b) { return a.first > b.first; }); - -// // 5️⃣ 从高到低选取直到累积超过阈值 -// for (auto& [score, k_block_idx] : scores) { -// if (cumsum >= required_sum) -// break; -// cumsum += score; -// retained.insert({q_block_idx, k_block_idx}); -// } - - -// // 6️⃣ 保证左侧(k <= q)邻域不被裁掉 -// // (Python 行为是保留对角线及左侧邻近块) -// for (int s = 1; s <= 2; s++) { // stride=2 可根据外部参数替换 -// if (q_block_idx >= static_cast(s)) -// retained.insert({q_block_idx, q_block_idx - s}); -// } - -// // 7️⃣ 保证对角块右邻域(但受 causal 约束) -// for (int s = 1; s <= 2; s++) { -// size_t right = q_block_idx + s; -// if (right < k_block_num) -// retained.insert({q_block_idx, right}); -// } - -// // 调试打印(默认注释) -// // std::cout << "[Head " << head_idx << "] Q=" << q_block_idx -// // << " required_sum=" << required_sum << " cumsum=" << cumsum -// // << " diag_score=" << diag_score << " retained=" << retained.size() -// // << std::endl; -// } - -// // 8️⃣ 修正 causal mask(与 Python 一致:禁止未来块) -// auto it = retained.begin(); -// while (it != retained.end()) { -// size_t q = it->first; -// size_t k = it->second; -// if (k > q) { // ✅ Python 中严格排除未来块 -// it = retained.erase(it); -// } else { -// ++it; -// } -// } - -// // 调试打印(默认注释) -// // std::cout << "Head " << head_idx << " selected blocks:"; -// // for (auto [a, b] : retained) -// // std::cout << " (" << a << "," << b << ")"; -// // std::cout << std::endl; -// } - -// return retval; -// } - -XAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep( - T* blocked_attention_scores_data, - const Shape& blocked_attention_scores_shape) { - - OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3, - "Expected shape [num_heads, q_block_num, k_block_num]"); - - size_t num_heads = blocked_attention_scores_shape[0]; - size_t q_block_num = blocked_attention_scores_shape[1]; - size_t k_block_num = blocked_attention_scores_shape[2]; - - // float blocked_attention_scores_values[q_block_num * k_block_num] = { - // 2.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, - // 1.1399f, 0.8601f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, - // 0.5426f, 0.8147f, 0.6427f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, - // 0.4169f, 0.5852f, 0.6589f, 0.3390f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, - // 0.5131f, 0.4026f, 0.4603f, 0.3615f, 0.2625f, 0.0000f, 0.0000f, 0.0000f, - // 0.3882f, 0.3218f, 0.3278f, 0.3583f, 0.3449f, 0.2589f, 0.0000f, 0.0000f, - // 0.3030f, 0.3146f, 0.2382f, 0.3002f, 0.2992f, 0.3479f, 0.1969f, 0.0000f, - // 0.2431f, 0.3503f, 0.3054f, 0.2146f, 0.2261f, 0.2692f, 0.1847f, 0.2065f - // }; - - // // 分配可写的 ov::float16 buffer - // // ov::float16* blocked_attention_scores_data = new ov::float16[num_heads * q_block_num * k_block_num]; - - // // 逐元素赋值 - // for (int i = 0; i < 64; ++i) { - // blocked_attention_scores_data[i] = ov::float16(blocked_attention_scores_values[i]); - // } - - // std::vector blocked_attention_scores_f32(num_heads * q_block_num * k_block_num); - // for (size_t i = 0; i < blocked_attention_scores_f32.size(); ++i) { - // blocked_attention_scores_f32[i] = static_cast(blocked_attention_scores_data[i]); - // } - - // print_blocked_attention_scores(blocked_attention_scores_data, - // num_heads, q_block_num, k_block_num); - - // 返回结果,每个 head 一个 set 存储 (q_block_idx, k_block_idx) - XAttentionRetainedBlockIndicesForAllHeads retval(num_heads); - - // 临时 mask 矩阵,用于模拟 Python mask - std::vector>> mask( - num_heads, std::vector>( - q_block_num, std::vector(k_block_num, false))); - - for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { - for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { - // Step0: diagonal 保留 - size_t diagonal_k = q_block_idx; - if (diagonal_k < k_block_num) { - mask[head_idx][q_block_idx][diagonal_k] = true; - } - // Step1: 首列保留 - mask[head_idx][q_block_idx][0] = true; - - // Step2: 构建 other_values(masked_fill) - std::vector> other_values; - for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { - if (mask[head_idx][q_block_idx][k_block_idx]) - continue; - size_t offset = head_idx * q_block_num * k_block_num - + q_block_idx * k_block_num - + k_block_idx; - other_values.emplace_back(static_cast(blocked_attention_scores_data[offset]), k_block_idx); - } - - // // Step4: 打印 other_values - // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] other_values:\n"; - // for (auto& [score, k_block_idx] : other_values) { - // std::cout << "(" << k_block_idx << ", " << score << ") "; - // } - // std::cout << std::endl; - - // Step3: 对 other_values 降序排序 - std::sort(other_values.begin(), other_values.end(), - [](const auto& a, const auto& b) { return a.first > b.first; }); - - // Step4: 构建 cumulative_sum_without_self,cat([0, diagonal_sum, sorted_values[:-1]]) - std::vector sorted_scores; - sorted_scores.push_back(0.0); // 前置0 - // diagonal + 首列分数 - size_t offset_diag = head_idx * q_block_num * k_block_num - + q_block_idx * k_block_num - + diagonal_k; - float diag_score = static_cast(blocked_attention_scores_data[offset_diag]); - float first_col_score = 0.0; - if (diagonal_k != 0) { - size_t offset_first = head_idx * q_block_num * k_block_num - + q_block_idx * k_block_num - + 0; - first_col_score = static_cast(blocked_attention_scores_data[offset_first]); - } - std::cout << diag_score << " " << diag_score << " " << first_col_score << " " << diag_score + first_col_score << std::endl; - sorted_scores.push_back(diag_score + first_col_score); - - // for (size_t i = 0; i + 1 < other_values.size(); i++) { - // sorted_scores.push_back(other_values[i].first); - // } - for (auto& p : other_values) { - sorted_scores.push_back(p.first); - } - if (q_block_idx == 0) { - sorted_scores.pop_back(); - } - // // Step4.1: 打印 sorted_scores - // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] sorted_scores: "; - // for (size_t i = 0; i < sorted_scores.size(); i++) { - // std::cout << sorted_scores[i] << " "; - // } - // std::cout << std::endl; - - - - // Step5: 计算 cumsum_without_self: cumsum of right-shifted sorted_scores - std::vector cumsum_without_self(sorted_scores.size(), 0.0); - float running = 0.0; - for (size_t i = 0; i < sorted_scores.size(); ++i) { - cumsum_without_self[i] = running; // 等价于 Python 的 cat([0, ...]) then cumsum, i.e. previous sum - running += sorted_scores[i]; - } - - // // 打印 cumsum_without_self(调试用) - // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] cumsum: "; - // for (size_t i = 0; i < cumsum_without_self.size(); i++) { - // std::cout << cumsum_without_self[i] << " "; - // } - // std::cout << std::endl; - - // Step6: 生成 required_sum(基于整行) - size_t offset_row_start = head_idx * q_block_num * k_block_num - + q_block_idx * k_block_num; - float row_sum = 0.0; - for (size_t k = 0; k < k_block_num; k++) { - row_sum += static_cast(blocked_attention_scores_data[offset_row_start + k]); - } - float required_sum = row_sum * m_threshold; - std::cout << "required_sum: " << required_sum << std::endl; - - - // Step7: 构建 index_mask - std::vector index_mask(cumsum_without_self.size(), false); - for (size_t i = 0; i < cumsum_without_self.size(); i++) { - index_mask[i] = (cumsum_without_self[i] < required_sum); - } - - // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] index_mask: "; - // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { - // std::cout << (index_mask[head_idx][q_block_idx][k_block_idx] ? "1 " : "0 "); - // } - // std::cout << std::endl; - - - // Step8: 构建 index 向量(torch.where(index_mask, index, 0)) - std::vector index(index_mask.size(), 0); - for (size_t i = 0; i < index_mask.size(); i++) { - if (index_mask[i]) { - // 索引来源:sorted_scores[0], [1], ... 对应哪些 k_block? - // 前两个为 [0:padding], [1:diag+col0], 后续对应 other_values - if (i == 0) index[i] = 0; // dummy - else if (i == 1) index[i] = diagonal_k; - else if (i - 2 < other_values.size()) - index[i] = other_values[i - 2].second; - else - index[i] = 0; - } - } - - // Step9: 模拟 Python mask[:, torch.arange(...), index] = True - // 即对每个 (head_idx, q_block_idx),将 index[i] 对应的 k_block 置 True - for (size_t i = 0; i < index.size(); i++) { - size_t k_block_idx = index[i]; - if (index_mask[i] && k_block_idx < k_block_num) { - mask[head_idx][q_block_idx][k_block_idx] = true; - } - } - - - // 打印 cumsum_without_self(调试用) - std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] required_sum: " << required_sum << std::endl; - - // Step7: 根据 index_mask 更新 mask - // 注意:sorted_scores 带有两个前缀项,因此 other_values 对应的 sorted_scores 索引从 2 开始 - // but we must only iterate the number of other_values actually included in sorted_scores. - // size_t included_count = 0; - // if (sorted_scores.size() > 2) { - // included_count = sorted_scores.size() - 2; - // } else { - // included_count = 0; - // } - - - // // 🔹 Step10.1: 打印当前 head、q_block 的 mask - // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] mask: "; - // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { - // std::cout << (mask[head_idx][q_block_idx][k_block_idx] ? "1 " : "0 "); - // } - // std::cout << std::endl; - - // for (size_t i = 0; i < included_count; ++i) { - // size_t idx_in_sorted = 2 + i; - // // 安全检查(通常应该不越界) - // if (idx_in_sorted < cumsum_without_self.size()) { - // if (cumsum_without_self[idx_in_sorted] < required_sum) { - // size_t k_block_idx = other_values[i].second; - // mask[head_idx][q_block_idx][k_block_idx] = true; - // } - // } else { - // // 如果发生越界,输出调试信息(不抛异常以便继续调试) - // std::cerr << "Debug: idx_in_sorted out of range: " << idx_in_sorted - // << " cumsum_size=" << cumsum_without_self.size() - // << " other_values.size()=" << other_values.size() - // << " sorted_scores.size()=" << sorted_scores.size() << std::endl; - // } - // } - - // // Step8: 保留左侧邻域(stride=2) - // for (int s = 1; s <= 2; s++) { - // if (q_block_idx >= static_cast(s)) { - // std::cout << head_idx << " " << q_block_idx << " " << q_block_idx - s << std::endl; - // mask[head_idx][q_block_idx][q_block_idx - s] = true; - // } - // } - - // // Step9: 保留右侧邻域(受 causal 约束) - // for (int s = 1; s <= 2; s++) { - // size_t right = q_block_idx + s; - // if (right < k_block_num) { - // std::cout << head_idx << " " << q_block_idx << " " << right << std::endl; - // mask[head_idx][q_block_idx][right] = true; - // } - // } - - // // Step10: causal mask,删除未来块 - // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { - // if (k_block_idx > q_block_idx) - // mask[head_idx][q_block_idx][k_block_idx] = false; - // } - - // 🔹 Step10.1: 打印当前 head、q_block 的 mask - // std::cout << "[Head " << head_idx << " Q=" << q_block_idx << "] mask: "; - // for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { - // std::cout << (mask[head_idx][q_block_idx][k_block_idx] ? "1 " : "0 "); - // } - // std::cout << std::endl; - - // Step11: 收集 mask 为 true 的块到 retval - for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { - if (mask[head_idx][q_block_idx][k_block_idx]) - retval[head_idx].insert({q_block_idx, k_block_idx}); - } - } - } - - return retval; -} - -void print_attn_score_buf_with_shape(const std::shared_ptr& buf, - size_t num_heads, - size_t rows, // 实际 buf 的第2维长度 - size_t cols, // 实际 buf 的第3维长度 - size_t show_first_n_cols = 0) { // 0 表示显示全部 - std::cout << "=== Debug: attn_score_buf (shape = [" << num_heads << ", " << rows << ", " << cols << "]) ===\n"; - for (size_t h = 0; h < num_heads; ++h) { - std::cout << "Head " << h << ":\n"; - for (size_t r = 0; r < rows; ++r) { - std::cout << std::setw(3) << r << ": "; - size_t nonzero_count = 0; - size_t limit = (show_first_n_cols == 0) ? cols : std::min(cols, (size_t)show_first_n_cols); - for (size_t c = 0; c < limit; ++c) { - size_t idx = h * rows * cols + r * cols + c; - double v = static_cast(buf[idx]); - if (std::fabs(v) > 1e-12) ++nonzero_count; - std::cout << std::fixed << std::setprecision(6) << v << " "; - } - if (limit < cols) std::cout << "..."; - std::cout << " (nonzero=" << nonzero_count << ")\n"; - } - // 打印非零掩码行(帮助看 pattern) - std::cout << "Nonzero mask per row: "; - for (size_t r = 0; r < rows; ++r) { - size_t nonzero = 0; - for (size_t c = 0; c < cols; ++c) { - size_t idx = h * rows * cols + r * cols + c; - if (std::fabs(static_cast(buf[idx])) > 1e-12) { - nonzero = 1; - break; - } - } - std::cout << nonzero; - } - std::cout << "\n\n"; - } - std::cout << "=== End attn_score_buf ===\n"; -} - -void print_qk_buf(const std::shared_ptr& qk_buf, - size_t num_heads, - size_t q_block_num, - size_t k_block_num, - size_t show_first_n_cols = 0) { - std::cout << "\n=== Debug: qk_buf (shape = [" - << num_heads << ", " << q_block_num << ", " << k_block_num << "]) ===" - << std::endl; - - for (size_t h = 0; h < num_heads; ++h) { - std::cout << "Head " << h << ":\n"; - for (size_t q = 0; q < q_block_num; ++q) { - std::cout << std::setw(3) << q << ": "; - size_t limit = (show_first_n_cols == 0) - ? k_block_num - : std::min(k_block_num, (size_t)show_first_n_cols); - size_t nonzero_count = 0; - for (size_t k = 0; k < limit; ++k) { - size_t idx = h * q_block_num * k_block_num + q * k_block_num + k; - double val = static_cast(qk_buf[idx]); - if (std::fabs(val) > 1e-12) - ++nonzero_count; - std::cout << std::fixed << std::setprecision(6) << val << " "; - } - if (limit < k_block_num) - std::cout << "..."; - std::cout << " (nonzero=" << nonzero_count << ")\n"; - } - - // 打印每行是否含非零的简单掩码 - std::cout << "Nonzero mask per row: "; - for (size_t q = 0; q < q_block_num; ++q) { - bool nonzero = false; - for (size_t k = 0; k < k_block_num; ++k) { - size_t idx = h * q_block_num * k_block_num + q * k_block_num + k; - if (std::fabs(static_cast(qk_buf[idx])) > 1e-12) { - nonzero = true; - break; - } - } - std::cout << (nonzero ? "1" : "0"); - } - std::cout << "\n\n"; - } - - std::cout << "=== End of qk_buf ===\n" << std::endl; -} - -void assign_qk_buf(std::shared_ptr& qk_buf, - size_t num_heads, - size_t q_block_num, - size_t k_block_num) { - std::vector data = { - 0.1953, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - -0.1914, 0.2695, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - -0.2305, -0.1211, -0.1211, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - 0.0703, -0.0859, 0.2148, -0.1367, -65504.0, -65504.0, -65504.0, -65504.0, - -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - -0.1367, -0.4766, -0.0039, 0.0273, 0.2031, -65504.0, -65504.0, -65504.0, - -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - -0.4414, 0.0703, 0.3477, 0.4102, 0.2891, 0.4453, -65504.0, -65504.0, - -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - -0.2266, -0.1797, 0.1992, 0.1523, 0.0586, 0.5234, -0.2070, -65504.0, - -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - -0.3164, -0.0117, 0.0312, 0.2422, 0.3047, 0.1562, -0.1172, 0.0820, - -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - 0.4648, -0.0117, 0.1680, -0.3086, -0.2695, 0.3906, -0.1641, -0.1406, - -0.1211, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - 0.3086, -0.0156, 0.0430, -0.0938, -0.1484, 0.2773, -0.2812, 0.0039, - -0.1133, -0.2656, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - 0.5078, -0.0664, -0.2266, -0.6055, -0.2383, -0.1719, -0.0195, 0.2461, - 0.0859, -0.1680, 0.1875, -65504.0, -65504.0, -65504.0, -65504.0, -65504.0, - - -0.4922, 0.4258, 0.2578, 0.4219, 0.0820, 0.3711, 0.4688, -0.5859, - -0.1328, 0.4102, -0.2266, 0.2695, -65504.0, -65504.0, -65504.0, -65504.0, - - -0.5586, 0.5430, 0.1211, 0.3359, -0.0859, -0.3477, 0.2500, 0.0391, - -0.1797, 0.5430, -0.2109, 0.7695, 0.1484, -65504.0, -65504.0, -65504.0, - - 0.0859, -0.1406, 0.0430, -0.1406, -0.0938, -0.2539, -0.0781, -0.0273, - -0.0820, -0.2578, 0.0469, -0.0781, -0.2227, -0.2969, -65504.0, -65504.0, - - -0.2109, -0.2539, 0.3086, 0.7109, 0.2695, 0.5547, -0.0977, -0.5430, - -0.1953, -0.3242, -0.1289, -0.0156, -0.0547, -0.5391, 0.1133, -65504.0, - - 0.0742, 0.1758, 0.2344, -0.1523, -0.2109, -0.0508, 0.0859, -0.1953, - -0.1562, 0.1680, 0.3242, 0.0195, -0.4141, -0.3164, -0.1133, 0.2383 - }; - - size_t total = num_heads * q_block_num * k_block_num; - if (data.size() != total) { - std::cerr << "Error: expected total=" << total << " but data.size=" << data.size() << std::endl; - return; - } - - // qk_buf = std::shared_ptr(new float[total]); - std::copy(data.begin(), data.end(), qk_buf.get()); -} - -void print_causal_mask_buf(const std::shared_ptr& causal_mask_buf, - size_t num_heads, - size_t q_block_num, - size_t k_block_num) { - std::cout << "=== Debug: causal_mask_buf ===" << std::endl; - - for (size_t h = 0; h < num_heads; ++h) { - std::cout << "Head " << h << ":\n"; - for (size_t q = 0; q < q_block_num; ++q) { - for (size_t k = 0; k < k_block_num; ++k) { - size_t idx = h * q_block_num * k_block_num + q * k_block_num + k; - auto val = static_cast(causal_mask_buf[idx]); - std::cout << std::setw(6) << val << " "; - } - std::cout << std::endl; - } - std::cout << std::endl; - } - - std::cout << "=== End of causal_mask_buf ===" << std::endl; -} - -void print_q_buf(const std::shared_ptr& q_buf, - size_t num_heads, - size_t q_block_num, - size_t head_dim) { - std::cout << "=== Debug: q_buf ===" << std::endl; - - for (size_t h = 0; h < num_heads; ++h) { - std::cout << "Head " << h << ":\n"; - for (size_t q = 0; q < q_block_num; ++q) { - std::cout << "Q" << std::setw(2) << q << ": "; - for (size_t d = 0; d < head_dim; ++d) { - size_t idx = h * q_block_num * head_dim + q * head_dim + d; - auto val = static_cast(q_buf[idx]); - std::cout << std::fixed << std::setprecision(4) << std::setw(8) << val << " "; - } - std::cout << std::endl; - } - std::cout << std::endl; - } - - std::cout << "=== End of q_buf ===" << std::endl; -} - -void print_k_buf(const std::shared_ptr& k_buf, - size_t num_heads, - size_t q_block_num, - size_t head_dim) { - std::cout << "=== Debug: k_buf ===" << std::endl; - - for (size_t h = 0; h < num_heads; ++h) { - std::cout << "Head " << h << ":\n"; - for (size_t q = 0; q < q_block_num; ++q) { - std::cout << "Q" << std::setw(2) << q << ": "; - for (size_t d = 0; d < head_dim; ++d) { - size_t idx = h * q_block_num * head_dim + q * head_dim + d; - auto val = static_cast(k_buf[idx]); - std::cout << std::fixed << std::setprecision(4) << std::setw(8) << val << " "; - } - std::cout << std::endl; - } - std::cout << std::endl; - } - - std::cout << "=== End of q_buf ===" << std::endl; -} - -void print_query_data(const T* data, const std::vector& shape, const std::string& name = "query_data") { - if (!data) { - std::cout << name << " is nullptr\n"; - return; - } - - std::cout << "=== " << name << " ===\n"; - - if (shape.size() == 3) { // [num_heads, q_block_num, k_block_num] - size_t H = shape[0]; - size_t Q = shape[1]; - size_t K = shape[2]; - - for (size_t h = 0; h < H; ++h) { - std::cout << "Head " << h << ":\n"; - for (size_t q = 0; q < Q; ++q) { - for (size_t k = 0; k < K; ++k) { - size_t idx = h * Q * K + q * K + k; - std::cout << std::fixed << std::setprecision(4) - << static_cast(data[idx]) << " "; - } - std::cout << "\n"; - } - std::cout << "\n"; - } - } else if (shape.size() == 4) { // [B, H, Q, K] - size_t B = shape[0]; - size_t H = shape[1]; - size_t Q = shape[2]; - size_t K = shape[3]; - - for (size_t b = 0; b < B; ++b) { - std::cout << "Batch " << b << ":\n"; - for (size_t h = 0; h < H; ++h) { - std::cout << " Head " << h << ":\n"; - for (size_t q = 0; q < Q; ++q) { - std::cout << " "; - for (size_t k = 0; k < K; ++k) { - size_t idx = b * H * Q * K + h * Q * K + q * K + k; - std::cout << std::fixed << std::setprecision(4) - << static_cast(data[idx]) << " "; - } - std::cout << "\n"; - } - std::cout << "\n"; - } - } - } else { - std::cout << "Unsupported shape size=" << shape.size() << "\n"; - } - - std::cout << "=== End of " << name << " ===\n"; -} - -void set_q_buf(std::shared_ptr &q_buf) { - const size_t B = 1; - const size_t H = 1; - const size_t Q = 32; - const size_t dim = 4; - - // tmp_data 用 float 填写你的 chunked_query 数据 - float tmp_data[B*H*Q*dim] = { - -0.3750, 1.0000, -0.2500, 0.2500, -1.0000, -0.5000, -0.1250, 0.0000, - -0.6250, -0.2500, 0.7500, 0.7500, -0.2500, 0.3750, -0.3750, -0.3750, - -0.6250, -0.7500, 0.1250, 0.1250, 1.0000, 0.7500, -0.8750, 0.1250, - 0.3750, 0.8750, -0.1250, -0.2500, 1.0000, 0.7500, 0.2500, -0.2500, - 0.1250, 0.8750, -0.8750, -0.3750, 0.6250, -0.3750, -0.1250, -1.0000, - -0.3750, 0.7500, 0.0000, 0.8750, 0.7500, 0.2500, 0.6250, -0.6250, - 0.8750, -0.2500, -0.1250, 0.7500, 0.2500, 0.3750, -0.6250, -0.7500, - -0.7500, 0.0000, -0.2500, 0.6250, -1.0000, -0.5000, -0.6250, -1.0000, - 0.8750, 0.2500, 0.5000, -0.6250, -0.1250, 0.7500, -0.7500, -0.5000, - 1.0000, -0.3750, 0.6250, 0.3750, 0.2500, 0.5000, -0.5000, 0.7500, - 0.1250, 0.0000, 0.0000, -1.0000, 0.2500, 0.6250, -0.5000, 0.8750, - -0.7500, -0.6250, 0.8750, 0.7500, 1.0000, 0.7500, 0.7500, 0.1250, - -0.5000, -1.0000, 0.0000, 0.7500, -0.8750, -0.1250, 1.0000, -0.1250, - 0.7500, 0.7500, -0.7500, -0.1250, 0.1250, -0.1250, 0.6250, 0.1250, - 0.7500, 0.6250, 0.5000, 0.8750, 1.0000, -0.6250, 0.5000, -0.6250, - 0.3750, 0.6250, -0.2500, -0.3750, -0.3750, 0.3750, 0.5000, -0.6250 - }; - - for (size_t idx = 0; idx < B*H*Q*dim; ++idx) { - q_buf[idx] = ov::float16(tmp_data[idx]); - } -} - - XAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, - const Shape& query_shape, - const T* key_data, - const Shape& key_shape) { - OPENVINO_ASSERT(query_shape.size() == 3); - OPENVINO_ASSERT(key_shape.size() == 3); - OPENVINO_ASSERT(key_shape[0] == query_shape[0]); - OPENVINO_ASSERT(key_shape[2] == query_shape[2]); - OPENVINO_ASSERT(query_shape[1] % m_stride == 0); - OPENVINO_ASSERT(key_shape[1] % m_stride == 0); - OPENVINO_ASSERT(query_shape[1] % m_block_size == 0); - OPENVINO_ASSERT(key_shape[1] % m_block_size == 0); - // print_query_data(query_data, {1, 32, 4}); - - size_t chunk_size = query_shape[1]; - size_t k_len = key_shape[1]; - size_t head_dim = query_shape[2]; - size_t num_heads = query_shape[0]; - size_t k_num_to_pad = ((k_len + chunk_size - 1) / chunk_size) * chunk_size - k_len; - Shape pad_key_shape = {num_heads, k_len + k_num_to_pad, head_dim}; - auto pad_key_buf = allocate_buf(pad_key_shape); - - for (size_t h = 0; h < num_heads; h++) - for (size_t t = 0; t < k_len; t++) - for (size_t d = 0; d < head_dim; d++) { - size_t offset = h * (k_len + k_num_to_pad) * head_dim + t * head_dim + d; - size_t original_offset = h * k_len * head_dim + t * head_dim + d; - pad_key_buf.get()[offset] = key_data[original_offset]; - } - - size_t k_chunk_num = (k_len + k_num_to_pad) / chunk_size; - size_t offset_token_chunk_num = k_chunk_num - 1; - size_t reshaped_chunk_size = chunk_size / m_stride; - // size_t reshaped_block_size = m_block_size / m_stride; - size_t k_reshaped_num_to_pad = k_num_to_pad / m_stride; - size_t k_reshaped_seq_len = (k_len + k_num_to_pad) / m_stride; - - // size_t num_blocks_per_chunk = reshaped_chunk_size / reshaped_block_size; - - // size_t q_block_num = chunk_size / m_block_size; - - // size_t k_block_num = (k_len + k_num_to_pad) / m_block_size; - - Shape reshaped_query_shape = {num_heads, query_shape[1] / m_stride, head_dim * m_stride}; - auto q_buf = allocate_buf(reshaped_query_shape); - diagonal_reshape_kdb1_no_batch(query_data, query_shape, q_buf.get(), reshaped_query_shape); - Shape reshaped_key_shape = {num_heads, pad_key_shape[1] / m_stride, head_dim * m_stride}; - auto k_buf = allocate_buf(reshaped_key_shape); - diagonal_reshape(pad_key_buf.get(), pad_key_shape, k_buf.get(), reshaped_key_shape, false); - Shape transpose_matmul_scaled_shape = {num_heads, query_shape[1] / m_stride, pad_key_shape[1] / m_stride}; - std::cout << "transpose_matmul_scaled_shape: \n"; - for (auto ii : transpose_matmul_scaled_shape) { - std::cout << ii << " "; - } - std::cout << std::endl; - auto qk_buf = allocate_buf(transpose_matmul_scaled_shape); - - - // print_q_buf(q_buf, num_heads, query_shape[1] / m_stride, head_dim * m_stride); - // set_q_buf(q_buf); - // print_q_buf(q_buf, num_heads, query_shape[1] / m_stride, head_dim * m_stride); - // print_k_buf(k_buf, num_heads, pad_key_shape[1] / m_stride, head_dim * m_stride); - transpose_matmul_scale(q_buf.get(), - k_buf.get(), - reshaped_query_shape, - reshaped_key_shape, - qk_buf.get(), - transpose_matmul_scaled_shape); - // print_qk_buf(qk_buf, num_heads, 16, 16); - - q_buf.reset(); - k_buf.reset(); - Shape causal_mask_shape = {num_heads, reshaped_chunk_size, reshaped_chunk_size * k_chunk_num}; - auto causal_mask_buf = allocate_buf(causal_mask_shape); - std::fill(causal_mask_buf.get(), causal_mask_buf.get() + ov::shape_size(causal_mask_shape), T(0)); - if (k_reshaped_num_to_pad) { - for (size_t h = 0; h < num_heads; h++) - for (size_t q = 0; q < reshaped_chunk_size; q++) - for (size_t k = k_reshaped_seq_len - k_reshaped_num_to_pad; k < k_reshaped_seq_len; k++) { - size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + - q * (reshaped_chunk_size * k_chunk_num) + k; - - causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); - } - } - - size_t chunk_start = offset_token_chunk_num * reshaped_chunk_size; - - size_t chunk_end = chunk_start + reshaped_chunk_size; - - for (size_t h = 0; h < num_heads; h++) - for (size_t q = 0; q < reshaped_chunk_size; q++) - for (size_t k = q + 1; k < reshaped_chunk_size; k++) { - size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + - q * (reshaped_chunk_size * k_chunk_num) + chunk_start + k; - - causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); - } - - for (size_t h = 0; h < num_heads; h++) - for (size_t q = 0; q < reshaped_chunk_size; q++) - for (size_t k = chunk_end; k < reshaped_chunk_size * k_chunk_num; k++) { - size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + - q * (reshaped_chunk_size * k_chunk_num) + k; - - causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); - } - - // slice [: , : , 0 ::1 , : ] since kdb=1 - - size_t out_size = - transpose_matmul_scaled_shape[0] * transpose_matmul_scaled_shape[1] * transpose_matmul_scaled_shape[2]; - - - // print_causal_mask_buf(causal_mask_buf, num_heads, reshaped_chunk_size, reshaped_chunk_size * k_chunk_num); - - for (size_t i = 0; i < out_size; i++) { - qk_buf.get()[i] += causal_mask_buf.get()[i]; - } - - - - causal_mask_buf.reset(); - - Shape attention_scores_shape = transpose_matmul_scaled_shape; - - auto attn_score_buf = allocate_buf(attention_scores_shape); - - // print_qk_buf(qk_buf, num_heads, 16, 16); - // assign_qk_buf(qk_buf, num_heads, 16, 16); - // print_qk_buf(qk_buf, num_heads, 16, 16); - - - softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape); - - qk_buf.reset(); - - // print_attn_score_buf_with_shape(attn_score_buf, - // transpose_matmul_scaled_shape[0], - // transpose_matmul_scaled_shape[1], - // transpose_matmul_scaled_shape[2]); - - - - - size_t antidiagonals_per_xattention_block = m_block_size / m_stride; - Shape block_sum_shape = {attention_scores_shape[0], - attention_scores_shape[1] / antidiagonals_per_xattention_block, - attention_scores_shape[2] / antidiagonals_per_xattention_block}; - - auto block_sum_buf = allocate_buf(block_sum_shape); - block_sum_attention_scores(attn_score_buf.get(), attention_scores_shape, block_sum_buf.get(), block_sum_shape); - attn_score_buf.reset(); - auto selected_block_indices = get_block_indices_to_keep(block_sum_buf.get(), block_sum_shape); - block_sum_buf.reset(); - - // The Python has the tril on the last q_block_num - - // So, to match, the simple_masks [: , : , -q_block_num : , -q_block_num : ] = where (tril, simple_masks, False) - - // But since the return is the set, we can do in the retained, erase the upper - - // Yes, already has. - - return selected_block_indices; - } - - std::shared_ptr allocate_buf(const Shape& shape) { - return std::shared_ptr(new T[ov::shape_size(shape)]); - } - - size_t pad_to_block(size_t token_length) { - return (token_length + m_block_size - 1) / m_block_size * m_block_size; - } - - double m_threshold; - - size_t m_block_size; - - size_t m_stride; -}; - -} // namespace ov::reference \ No newline at end of file diff --git a/src/core/tests/reference/xattention.cpp b/src/core/tests/reference/xattention.cpp deleted file mode 100644 index 78ad6744c17053..00000000000000 --- a/src/core/tests/reference/xattention.cpp +++ /dev/null @@ -1,550 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include -#include - -#include - -double DEFAULT_THRESHOLD = 0.8; -size_t DEFAULT_BLOCK_SIZE = 32; -size_t DEFAULT_STRIDE = 8; - -struct E2EBlockSelectTestData { - ov::Shape q_shape; - std::vector q_data; - ov::Shape k_shape; - std::vector k_data; - double threshold; - size_t block_size; - size_t stride; -}; - -using XAttentionE2EBlockSelectTest = ::testing::TestWithParam; - -std::vector E2E_BLOCK_SELECT_TEST_CASES = {{ - {2, 4, 4}, - // clang-format off - { - 3.144, 8.512, 8.518, -8.386, - 7.889, -5.721, 5.507, 4.295, - -6.624, -8.463, 7.474, 9.879, - 4.534, -5.908, -9.388, 2.356, - - 7.497, 8.186, -8.658, -4.796, - -8.248, -9.797, -7.907, -4.513, - 3.469, 7.633, 7.244, -6.844, - -7.173, 4.450, 6.705, -7.035 - }, - // clang-format on - {2, 4, 4}, - // clang-format off - { - 3.144, 8.512, 8.518, -8.386, - 7.889, -5.721, 5.507, 4.295, - -6.624, -8.463, 7.474, 9.879, - 4.534, -5.908, -9.388, 2.356, - - 7.497, 8.186, -8.658, -4.796, - -8.248, -9.797, -7.907, -4.513, - 3.469, 7.633, 7.244, -6.844, - -7.173, 4.450, 6.705, -7.035 - }, - // clang-format on - - /* threshold = */ 0.8, - /* block_size = */ 2, - /* stride = */ 2, -}}; - -TEST_P(XAttentionE2EBlockSelectTest, SelectsBlocksWithoutThrowing) { - auto test_struct = GetParam(); - ov::reference::XAttentionBlockSelector selector(test_struct.threshold, - test_struct.block_size, - test_struct.stride); - - EXPECT_NO_THROW(selector.select_blocks(test_struct.q_data.data(), - test_struct.q_shape, - test_struct.k_data.data(), - test_struct.k_shape)); -}; - -INSTANTIATE_TEST_SUITE_P(VariousInputs, XAttentionE2EBlockSelectTest, ::testing::ValuesIn(E2E_BLOCK_SELECT_TEST_CASES)); - -struct DiagonalReshapeTestData { - ov::Shape in_shape; - std::vector in_data; - bool is_antidiagonal; - size_t block_size; - size_t stride; - ov::Shape out_shape; - std::vector ref_out_data; -}; - -using XAttentionDiagonalReshapeTest = ::testing::TestWithParam; - -std::vector DIAGONAL_RESHAPE_TEST_CASES = { - { - {2, 4, 4}, - // clang-format off - { - 3.144, 8.512, 8.518, -8.386, - 7.889, -5.721, 5.507, 4.295, - -6.624, -8.463, 7.474, 9.879, - 4.534, -5.908, -9.388, 2.356, - - 7.497, 8.186, -8.658, -4.796, - -8.248, -9.797, -7.907, -4.513, - 3.469, 7.633, 7.244, -6.844, - -7.173, 4.450, 6.705, -7.035 - }, - // clang-format on - - /* is_antidiagonal = */ true, - /* block_size = */ 2, - /* stride = */ 2, - {2, 2, 8}, - - // clang-format off - { - 4.534, -5.908, -9.388, 2.356, -6.624, -8.463, 7.474, 9.879, - 7.889, -5.721, 5.507, 4.295, 3.144, 8.512, 8.518, -8.386, - - -7.173, 4.450, 6.705, -7.035, 3.469, 7.633, 7.244, -6.844, - -8.248, -9.797, -7.907, -4.513, 7.497, 8.186, -8.658, -4.796, - }, - // clang-format on - }, - { - {2, 4, 4}, - // clang-format off - { - 3.144, 8.512, 8.518, -8.386, - 7.889, -5.721, 5.507, 4.295, - -6.624, -8.463, 7.474, 9.879, - 4.534, -5.908, -9.388, 2.356, - - 7.497, 8.186, -8.658, -4.796, - -8.248, -9.797, -7.907, -4.513, - 3.469, 7.633, 7.244, -6.844, - -7.173, 4.450, 6.705, -7.035 - }, - // clang-format on - - /* is_antidiagonal = */ false, - /* block_size = */ 2, - /* stride = */ 2, - {2, 2, 8}, - - // clang-format off - { - 3.144, 8.512, 8.518, -8.386, 7.889, -5.721, 5.507, 4.295, - -6.624, -8.463, 7.474, 9.879, 4.534, -5.908, -9.388, 2.356, - - 7.497, 8.186, -8.658, -4.796, -8.248, -9.797, -7.907, -4.513, - 3.469, 7.633, 7.244, -6.844, -7.173, 4.450, 6.705, -7.035 - }, - // clang-format on - }, - { - {2, 9, 2}, - // clang-format off - { - 1.110, -4.244, - 3.530, -1.083, - 3.664, -2.459, - 3.930, -2.122, - -4.142, 2.837, - -7.413, 5.855, - 1.354, -7.748, - 0.264, 7.095, - -8.410, 6.247, - - -7.832, 9.163, - -7.414, -3.682, - -5.429, 7.854, - 1.767, 5.950, - -0.841, 1.935, - 3.568, 8.530, - 9.438, -2.421, - -5.892, 7.820, - -9.869, -7.636 - }, - // clang-format on - - /* is_antidiagonal = */ true, - /* block_size = */ 9, - /* stride = */ 3, - {2, 3, 6}, - - // clang-format off - { - -8.410, 6.247, 0.264, 7.095, 1.354, -7.748, - -7.413, 5.855, -4.142, 2.837, 3.930, -2.122, - 3.664, -2.459, 3.530, -1.083, 1.110, -4.244, - - -9.869, -7.636, -5.892, 7.820, 9.438, -2.421, - 3.568, 8.530, -0.841, 1.935, 1.767, 5.950, - -5.429, 7.854, -7.414, -3.682, -7.832, 9.163, - }, - // clang-format on - }, - { - {2, 9, 2}, - // clang-format off - { - 1.110, -4.244, - 3.530, -1.083, - 3.664, -2.459, - 3.930, -2.122, - -4.142, 2.837, - -7.413, 5.855, - 1.354, -7.748, - 0.264, 7.095, - -8.410, 6.247, - - -7.832, 9.163, - -7.414, -3.682, - -5.429, 7.854, - 1.767, 5.950, - -0.841, 1.935, - 3.568, 8.530, - 9.438, -2.421, - -5.892, 7.820, - -9.869, -7.636 - }, - // clang-format on - - /* is_antidiagonal = */ false, - /* block_size = */ 9, - /* stride = */ 3, - {2, 3, 6}, - - // clang-format off - { - 1.110, -4.244, 3.530, -1.083, 3.664, -2.459, - 3.930, -2.122, -4.142, 2.837, -7.413, 5.855, - 1.354, -7.748, 0.264, 7.095, -8.410, 6.247, - - -7.832, 9.163, -7.414, -3.682, -5.429, 7.854, - 1.767, 5.950, -0.841, 1.935, 3.568, 8.530, - 9.438, -2.421, -5.892, 7.820, -9.869, -7.636 - }, - // clang-format on - }, -}; - -TEST_P(XAttentionDiagonalReshapeTest, ReshapesDiagonally) { - auto test_struct = GetParam(); - ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape)); - ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape)); - - ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, - test_struct.block_size, - test_struct.stride); - std::vector test_out_data(test_struct.ref_out_data.size()); - selector.diagonal_reshape(test_struct.in_data.data(), - test_struct.in_shape, - test_out_data.data(), - test_struct.out_shape, - test_struct.is_antidiagonal); - EXPECT_EQ(test_out_data, test_struct.ref_out_data); -} - -INSTANTIATE_TEST_SUITE_P(VariousInputs, - XAttentionDiagonalReshapeTest, - ::testing::ValuesIn(DIAGONAL_RESHAPE_TEST_CASES)); - -struct TransposeMatmulScaleTestData { - ov::Shape reshaped_query_shape; - std::vector reshaped_query_data; - ov::Shape reshaped_key_shape; - std::vector reshaped_key_data; - size_t block_size; - size_t stride; - ov::Shape out_shape; - std::vector ref_out_data; -}; - -using XAttentionTransposeMatmulScaleTest = ::testing::TestWithParam; - -std::vector TRANSPOSE_MATMUL_SCALE_TEST_CASES = { - { - {2, 2, 8}, - // clang-format off - { - 4.534, -5.908, -9.388, 2.356, -6.624, -8.463, 7.474, 9.879, - 7.889, -5.721, 5.507, 4.295, 3.144, 8.512, 8.518, -8.386, - - -7.173, 4.450, 6.705, -7.035, 3.469, 7.633, 7.244, -6.844, - -8.248, -9.797, -7.907, -4.513, 7.497, 8.186, -8.658, -4.796, - }, - // clang-format on - - {2, 3, 8}, - - // clang-format off - { - -2.731, -0.545, 6.128, -6.175, -2.198, -1.275, -8.617, -0.683, - 3.085, 7.929, -1.127, 5.369, -6.891, 9.582, -6.954, 1.189, - -0.610, -6.310, -9.216, -1.196, 9.509, -8.119, 4.652, -4.435, - - -0.026, -9.294, 7.862, 9.318, -6.012, 8.252, -3.224, -0.710, - -2.915, -7.362, -5.553, 0.097, -4.509, 6.993, 2.021, 2.870, - -3.682, 8.637, -9.922, -6.336, -2.949, 4.339, -2.807, -9.192 - }, - - /* block_size = */ 2, - /* stride = */ 2, - {2, 2, 3}, - - // clang-format off - { - -31.760349, -21.32551225, 28.723734, - -24.15923075, -3.369805999, 3.2507255, - - -7.593187497, -4.258293245, 27.08950801, - 10.21206450, 32.95415775, 33.649577 - }, - // clang-format on - }, -}; - -TEST_P(XAttentionTransposeMatmulScaleTest, TransposesMatmulsAndScales) { - auto test_struct = GetParam(); - ASSERT_EQ(test_struct.reshaped_key_data.size(), ov::shape_size(test_struct.reshaped_key_shape)); - ASSERT_EQ(test_struct.reshaped_query_data.size(), ov::shape_size(test_struct.reshaped_query_shape)); - ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape)); - - ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, - test_struct.block_size, - test_struct.stride); - std::vector test_out_data(test_struct.ref_out_data.size()); - selector.transpose_matmul_scale(test_struct.reshaped_query_data.data(), - test_struct.reshaped_key_data.data(), - test_struct.reshaped_query_shape, - test_struct.reshaped_key_shape, - test_out_data.data(), - test_struct.out_shape); - - EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-8), test_struct.ref_out_data)); -} - -INSTANTIATE_TEST_SUITE_P(VariousInputs, - XAttentionTransposeMatmulScaleTest, - ::testing::ValuesIn(TRANSPOSE_MATMUL_SCALE_TEST_CASES)); - -struct SoftmaxTestData { - ov::Shape in_shape; - std::vector in_data; - ov::Shape out_shape; - std::vector ref_out_data; -}; - -using XAttentionSoftmaxTest = ::testing::TestWithParam; - -std::vector SOFTMAX_TEST_CASES = { - { - {2, 2, 4}, - // clang-format off - { - 4.534, -5.908, -9.388, 2.356, - 7.889, -5.721, 5.507, 4.295, - - -7.173, 4.450, 6.705, -7.035, - -8.248, -9.797, -7.907, -4.513 - }, - // clang-format on - - {2, 2, 4}, - - // clang-format off - { - 0.898232, 2.62111e-05, 8.07497e-07, 0.101741, - 0.892973, 1.09671e-06, 0.08248, 0.0245462, - - 8.50252e-07, 0.0949189, 0.905079, 9.76069e-07, - 0.0224685, 0.00477366, 0.0315986, 0.941159 - }, - }, -}; - -TEST_P(XAttentionSoftmaxTest, SoftmaxIsCorrect) { - auto test_struct = GetParam(); - ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape)); - ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape)); - - ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, - DEFAULT_BLOCK_SIZE, - DEFAULT_STRIDE); - std::vector test_out_data(test_struct.ref_out_data.size()); - selector.softmax(test_struct.in_data.data(), test_struct.in_shape, test_out_data.data(), test_struct.out_shape); - - EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-5), test_struct.ref_out_data)); -} - -INSTANTIATE_TEST_SUITE_P(VariousInputs, - XAttentionSoftmaxTest, - ::testing::ValuesIn(SOFTMAX_TEST_CASES)); - -struct BlockSumTestData { - ov::Shape in_shape; - std::vector in_data; - size_t block_size; - size_t stride; - ov::Shape out_shape; - std::vector ref_out_data; -}; - -using XAttentionBlockSumTest = ::testing::TestWithParam; - -std::vector BLOCK_SUM_TEST_CASES = { - { - {2, 4, 8}, - // clang-format off - { - 0.1117, 0.0780, 0.1347, 0.0885, 0.1942, 0.0922, 0.1184, 0.1824, - 0.1488, 0.1766, 0.0852, 0.1239, 0.0930, 0.1220, 0.1367, 0.1138, - 0.1410, 0.0861, 0.0774, 0.1325, 0.1478, 0.1689, 0.0885, 0.1579, - 0.1248, 0.1038, 0.1842, 0.0935, 0.1813, 0.0890, 0.0897, 0.1336, - - 0.0905, 0.1049, 0.1263, 0.0953, 0.1018, 0.1297, 0.1659, 0.1855, - 0.1373, 0.1791, 0.1005, 0.1286, 0.1492, 0.1373, 0.0820, 0.0860, - 0.0997, 0.1285, 0.0786, 0.1366, 0.1963, 0.0904, 0.1488, 0.1211, - 0.1859, 0.1174, 0.1364, 0.0930, 0.1028, 0.1034, 0.1699, 0.0912 - }, - // clang-format on - - /* block_size = */ 8, - /* stride = */ 4, - {2, 2, 4}, - - // clang-format off - { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, - - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 - }, - }, -}; -TEST_P(XAttentionBlockSumTest, BlockSumIsCorrect) { - auto test_struct = GetParam(); - ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape)); - ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape)); - - ov::reference::XAttentionBlockSelector selector(DEFAULT_THRESHOLD, - test_struct.block_size, - test_struct.stride); - std::vector test_out_data(test_struct.ref_out_data.size()); - selector.block_sum_attention_scores(test_struct.in_data.data(), test_struct.in_shape, test_out_data.data(), test_struct.out_shape); - - EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-5), test_struct.ref_out_data)); -} - -INSTANTIATE_TEST_SUITE_P(VariousInputs, - XAttentionBlockSumTest, - ::testing::ValuesIn(BLOCK_SUM_TEST_CASES)); - -struct BlockSelectTestData { - ov::Shape in_shape; - std::vector in_data; - double threshold; - ov::reference::XAttentionRetainedBlockIndicesForAllHeads ref_retained_block_indices; -}; - -using XAttentionBlockSelectTest = ::testing::TestWithParam; - -std::vector BLOCK_SELECT_TEST_CASES = { - { - {2, 2, 4}, - // clang-format off - { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, - - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 - }, - // clang-format on - /* threshold = */ 0.25, - { - {{1, 2}, {0, 3}}, - {{1, 0}, {1, 3}}, - }}, - - {{2, 2, 4}, - // clang-format off - { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, - - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 - }, - // clang-format on - /* threshold = */ 0.35, - { - {{1, 2}, {0, 3}, {0, 0}}, - {{1, 0}, {1, 3}, {0, 3}}, - }}, - {{2, 2, 4}, - // clang-format off - { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, - - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 - }, - // clang-format on - /* threshold = */ 0.1, - { - {{1, 2}}, - {{1, 0}}, - }}, - {{2, 2, 4}, - // clang-format off - { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, - - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 - }, - // clang-format on - /* threshold = */ 0.0, - { - {}, - {}, - }}, - {{2, 2, 4}, - // clang-format off - { - 0.5151, 0.4323, 0.5014, 0.5513, - 0.4557, 0.4876, 0.5870, 0.4697, - - 0.5118, 0.4507, 0.5180, 0.5194, - 0.5315, 0.4446, 0.4929, 0.5310 - }, - // clang-format on - /* threshold = */ 1.0, - { - {{1, 2}, {0, 3}, {0, 0}, {0, 2}, {1, 1}, {1, 3}, {1, 0}, {0, 1}}, - {{1, 0}, {1, 3}, {0, 3}, {0, 2}, {0, 0}, {1, 2}, {0, 1}, {1, 1}}, - }}, -}; - -TEST_P(XAttentionBlockSelectTest, BlockSelectionIsCorrect) { - auto test_struct = GetParam(); - ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape)); - - ov::reference::XAttentionBlockSelector selector(test_struct.threshold, DEFAULT_BLOCK_SIZE, DEFAULT_STRIDE); - auto test_result = selector.get_block_indices_to_keep(test_struct.in_data.data(), test_struct.in_shape); - - EXPECT_EQ(test_result, test_struct.ref_retained_block_indices); -} - -INSTANTIATE_TEST_SUITE_P(VariousInputs, XAttentionBlockSelectTest, ::testing::ValuesIn(BLOCK_SELECT_TEST_CASES)); \ No newline at end of file diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index 3c970d2cdcad62..3884d46bc27658 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #include "paged_attention_gpu_test.hpp" #include "random_generator.hpp" From fdbba78d19a90e1e316bf5a5af622ecb334bdaa0 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 14 Oct 2025 22:31:31 +0800 Subject: [PATCH 64/96] Clean code --- .../intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp | 5 +++++ .../tests/unit/test_utils/paged_attention_gpu_test.hpp | 1 - 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index 3884d46bc27658..b9438980a1e376 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -12,6 +12,11 @@ #include #include +#include "openvino/reference/divide.hpp" +#include "openvino/reference/matmul.hpp" +#include "openvino/reference/softmax.hpp" +#include "openvino/reference/transpose.hpp" +#include "openvino/runtime/tensor.hpp" #include "paged_attention_gpu_test.hpp" #include "random_generator.hpp" #include "test_utils.h" diff --git a/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp b/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp index bd234c5bd42aeb..c5dd8e20e89b3e 100644 --- a/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp +++ b/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp @@ -14,7 +14,6 @@ #include #include #include -#include using namespace cldnn; using namespace ov::intel_gpu; From 2ade1e1abb5e823c9a162a0901f74252e32c6e52 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 14 Oct 2025 22:56:29 +0800 Subject: [PATCH 65/96] Clean code --- .../unit/test_cases/xattention_gpu_test.cpp | 82 +++++++++---------- 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index b9438980a1e376..a6be28bad3e47d 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -11,12 +11,12 @@ #include #include #include - #include "openvino/reference/divide.hpp" #include "openvino/reference/matmul.hpp" #include "openvino/reference/softmax.hpp" #include "openvino/reference/transpose.hpp" #include "openvino/runtime/tensor.hpp" + #include "paged_attention_gpu_test.hpp" #include "random_generator.hpp" #include "test_utils.h" @@ -27,7 +27,8 @@ using namespace ::tests; using Shape = std::vector; -using CMXAttentionBlockIndex = std::pair; // .first is the *query* dimension block index, .second is *key* +using CMXAttentionBlockIndex = + std::pair; // .first is the *query* dimension block index, .second is *key* using CMXAttentionRetainedBlockIndices = std::set; using CMXAttentionRetainedBlockIndicesForAllHeads = std::vector; @@ -430,19 +431,19 @@ struct xAttentionReference { private: std::pair, std::vector> run_reference(const std::vector& query_data, - const std::vector& key_data, - const std::vector& value_data, - int num_queries, - int num_keys, - int num_heads, - int k_head_size, - int v_head_size, - int window_size, - int sliding_window_size, - float scale, - double threshold = 0.9, - size_t block_size = 128, - size_t stride = 16) { + const std::vector& key_data, + const std::vector& value_data, + int num_queries, + int num_keys, + int num_heads, + int k_head_size, + int v_head_size, + int window_size, + int sliding_window_size, + float scale, + double threshold = 0.9, + size_t block_size = 128, + size_t stride = 16) { auto query_shape_bfyx = ov::PartialShape{1, num_queries, num_heads, k_head_size}; auto key_shape_bfyx = ov::PartialShape{1, num_keys, num_heads, k_head_size}; auto value_shape_bfyx = ov::PartialShape{1, num_keys, num_heads, v_head_size}; @@ -486,35 +487,30 @@ struct xAttentionReference { ov::Shape key_shape_3d = {static_cast(num_heads), static_cast(num_keys), static_cast(k_head_size)}; CMXAttentionRetainedBlockIndicesForAllHeads retained_blocks; - if (num_queries >= static_cast(block_size)) { - size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; - size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; - std::vector query_padded(num_heads * padded_q * k_head_size, ov::float16(0)); - std::vector key_padded(num_heads * padded_k * k_head_size, ov::float16(0)); - - for (int h = 0; h < num_heads; ++h) { - std::copy_n(&query_data_3d[h * num_queries * k_head_size], - num_queries * k_head_size, - &query_padded[h * padded_q * k_head_size]); - std::copy_n(&key_data_3d[h * num_keys * k_head_size], - num_keys * k_head_size, - &key_padded[h * padded_k * k_head_size]); - } + if (num_queries >= static_cast(block_size)) { + size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; + size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; + std::vector query_padded(num_heads * padded_q * k_head_size, ov::float16(0)); + std::vector key_padded(num_heads * padded_k * k_head_size, ov::float16(0)); + + for (int h = 0; h < num_heads; ++h) { + std::copy_n(&query_data_3d[h * num_queries * k_head_size], num_queries * k_head_size, &query_padded[h * padded_q * k_head_size]); + std::copy_n(&key_data_3d[h * num_keys * k_head_size], num_keys * k_head_size, &key_padded[h * padded_k * k_head_size]); + } - ov::Shape query_shape_padded = {static_cast(num_heads), padded_q, static_cast(k_head_size)}; - ov::Shape key_shape_padded = {static_cast(num_heads), padded_k, static_cast(k_head_size)}; + ov::Shape query_shape_padded = {static_cast(num_heads), padded_q, static_cast(k_head_size)}; + ov::Shape key_shape_padded = {static_cast(num_heads), padded_k, static_cast(k_head_size)}; - std::vector query_padded_f32(query_padded.size()); - std::vector key_padded_f32(key_padded.size()); - for (size_t i = 0; i < query_padded.size(); ++i) - query_padded_f32[i] = static_cast(query_padded[i]); - for (size_t i = 0; i < key_padded.size(); ++i) - key_padded_f32[i] = static_cast(key_padded[i]); + std::vector query_padded_f32(query_padded.size()); + std::vector key_padded_f32(key_padded.size()); + for (size_t i = 0; i < query_padded.size(); ++i) + query_padded_f32[i] = static_cast(query_padded[i]); + for (size_t i = 0; i < key_padded.size(); ++i) + key_padded_f32[i] = static_cast(key_padded[i]); - CMXAttentionBlockSelector selector(threshold, block_size, stride); - retained_blocks = selector.select_blocks(query_padded_f32.data(), query_shape_padded, - key_padded_f32.data(), key_shape_padded); - } + CMXAttentionBlockSelector selector(threshold, block_size, stride); + retained_blocks = selector.select_blocks(query_padded_f32.data(), query_shape_padded, key_padded_f32.data(), key_shape_padded); + } auto mask_mem = get_mask_mem_combined_multi_head(num_queries, num_keys, num_heads, sliding_window_size, retained_blocks, block_size); topology topology; @@ -961,7 +957,7 @@ struct xAttentionTest : public ::testing::TestWithParam { mismatch_count++; } } - EXPECT_LE(mismatch_count, int(data_output_mem->count() * 0.02)); + EXPECT_LE(mismatch_count, int(data_output_mem->count() * 0.04)); } if (scores_output_mem) { @@ -973,7 +969,7 @@ struct xAttentionTest : public ::testing::TestWithParam { mismatch_count++; } } - EXPECT_LE(mismatch_count, int(scores_output_mem->count() * 0.02)); + EXPECT_LE(mismatch_count, int(scores_output_mem->count() * 0.04)); } } }; From 4a82167d018d99b4cc2663051a7e6c03d7f9f6c6 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Wed, 15 Oct 2025 21:56:13 +0800 Subject: [PATCH 66/96] Clean code --- .../test_cases/paged_attention_gpu_test.cpp | 642 +++++++++++++- .../unit/test_cases/xattention_gpu_test.cpp | 836 +++++++++++++++--- .../test_utils/paged_attention_gpu_test.hpp | 701 --------------- 3 files changed, 1341 insertions(+), 838 deletions(-) delete mode 100644 src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 160efddcccfdec..687da05dc8bb77 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1,10 +1,9 @@ -// Copyright (C) 2025 Intel Corporation +// Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include "test_utils.h" #include "random_generator.hpp" -#include "paged_attention_gpu_test.hpp" #include #include @@ -20,6 +19,643 @@ using namespace cldnn; using namespace ov::intel_gpu; using namespace ::tests; +/* +* PagedAttention inputs: +* [0]: query +* shape: [batch_size_in_tokens, num_heads * head_size], type: f16 +* [1]: key +* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 +* [2]: value  +* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 +* [3]: key_cache +* shape: [num_blocks, num_kv_heads, head_size, block_size], type: f16 +* [4]: value_cache +* shape: [num_blocks, num_kv_heads, block_size, head_size], type: f16 +* [5]: past_lens +* shape: [batch_size_in_sequences], type: i32 +* [6]: subsequence_begins +* shape: [batch_size_in_sequences + 1], type: i32 +* [7]: block_indices +* Shape: [num_blocks], type: i32 +* [8]: block_indices_begins +* Shape: [batch_size_in_sequences + 1], type: i32 +* [9]: scale, optional +* [10]: sliding_window, optional +* [11]: alibi_slopes, optional +* [12]: max_context_len +* shape: [], type: i32 +* [13]: score_aggregation_window​, optional​, shape: [batch_size_in_sequences] +* [14]: rotated_block_indices​, optional​ +* shape: [num_rotated_blocks]​, type: i32 +* [15]: rotation_deltas​, optional​ +* shape: [num_rotated_blocks, BLOCK_SIZE]​ || [num_rotated_blocks, 1]​, type: i32 +* [16]: rotation_trig_lut​, optional​ +* shape: [max_num_batched_tokens / BLOCK_SIZE, head_size]​ || [max_num_batched_tokens, head_size], type: f16 +*/ + + +enum class ScoresMode { + DISABLED = 0, + LAST_TOKEN, + SNAPKV +}; + +struct SubsequenceDescriptor { + int num_tokens; + int past_len; +}; + +struct CacheRotationDescriptor { + bool apply_rotation; + // configures 2nd dimension of rotation_deltas + // if per_block is true, single value is used for all tokens inside the block + // otherwise, each token uses an independent value + bool per_block; +}; + +struct PagedAttentionManager { + int num_heads; + int k_head_size; + int v_head_size; + int block_size; + int sliding_window_size; + bool kv_cache_compression; + ov::internal::CacheQuantMode key_cache_quant_mode; + bool has_score_aggregation; + CacheRotationDescriptor rotation_config; + std::vector subsequence_descs; + + // per-subsequence QKV inputs + std::vector> query_data; // {[1, num_tokens, num_heads, k_head_size], ..} + std::vector> key_data; // {[1, past_len + num_tokens, num_heads, k_head_size], ..} + std::vector> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..} + + // common PA inputs + std::vector past_lens; + std::vector subsequence_begins; + std::vector block_indices; + std::vector block_indices_begins; + std::vector max_context_len; + std::vector score_aggregation_window; + + // score aggregation related inputs + std::vector score_aggregation; + + // rotation related inputs + std::vector rotated_block_indices; + std::vector rotation_deltas; + std::vector rotation_trig_lut; + + std::vector xattention_threshold; + std::vector xattention_block_size; + std::vector xattention_stride; + + cldnn::engine& test_engine; + cldnn::stream& test_stream; + tests::random_generator& rg; + + PagedAttentionManager(tests::random_generator& rg, + cldnn::engine& engine, + cldnn::stream& stream, + const std::vector& subsequence_descs, + int num_heads, + int k_head_size, + int v_head_size, + int block_size, + int sliding_window_size, + bool kv_cache_compression, + ov::internal::CacheQuantMode key_cache_quant_mode, + bool has_score_aggregation, + CacheRotationDescriptor rotation_config) + : num_heads(num_heads) + , k_head_size(k_head_size) + , v_head_size(v_head_size) + , block_size(block_size) + , sliding_window_size(sliding_window_size) + , kv_cache_compression(kv_cache_compression) + , key_cache_quant_mode(key_cache_quant_mode) + , has_score_aggregation(has_score_aggregation) + , rotation_config(rotation_config) + , subsequence_descs(subsequence_descs) + , test_engine(engine) + , test_stream(stream) + , rg(rg) { + // init subsequence_begins and block_indices_begins + subsequence_begins.push_back(0); + block_indices_begins.push_back(0); + + int max_len = 0; + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + const auto& subsequence_desc = subsequence_descs[i]; + max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); + + query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); + key_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); + value_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); + + past_lens.push_back(subsequence_desc.past_len); + int subsequence_start_pos = subsequence_begins[i]; + int subsequence_end_pos = subsequence_start_pos + subsequence_desc.num_tokens; + subsequence_begins.push_back(subsequence_end_pos); + + int subsequence_length = subsequence_desc.num_tokens + subsequence_desc.past_len; + int required_blocks = ceil_div(subsequence_length, block_size); + int start_block_idx = block_indices.empty() ? 0 : block_indices.back() + 1; + int end_block_idx = start_block_idx + required_blocks; + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + block_indices.push_back(block_idx); + } + + int block_indices_start_pos = block_indices_begins[i]; + int block_indices_end_pos = block_indices_start_pos + required_blocks; + block_indices_begins.push_back(block_indices_end_pos); + } + max_context_len.push_back(max_len); + + if (rotation_config.apply_rotation) { + // iterate over KV-cache blocks and apply cache rotation to every second + // fully occupied block + for (size_t i = 0; i < subsequence_descs.size(); i++) { + const auto& subsequence_desc = subsequence_descs[i]; + int past_len = subsequence_desc.past_len; + int start_block_idx = block_indices_begins[i]; + for (int block_idx = 1; block_idx < past_len / block_size; block_idx++) { + if (block_idx % 2 != 0) { + rotated_block_indices.push_back(start_block_idx + block_idx); + } + } + } + + if (!rotated_block_indices.empty()) { + rotation_deltas = generate_rotation_deltas_data(rg, + max_context_len[0], + rotated_block_indices.size(), + block_size, + rotation_config.per_block); + rotation_trig_lut = generate_rotation_trig_lut_data(rg, max_context_len[0], k_head_size); + } + } + + if (has_score_aggregation) { + for (const auto& subsequence_desc : subsequence_descs) { + const auto max_tokens = 10; + auto max_window_size = std::min(subsequence_desc.num_tokens, max_tokens); + auto window_size = rg.generate_random_val(1, max_window_size); + score_aggregation.push_back(window_size); + } + } + } + + memory::ptr get_query_memory() { + return get_QKV_memory(query_data, k_head_size, false); + } + + memory::ptr get_key_memory() { + return get_QKV_memory(key_data, k_head_size, true); + } + + memory::ptr get_value_memory() { + return get_QKV_memory(value_data, v_head_size, true); + } + +#if ENABLE_PA_CM_PATH + memory::ptr get_key_cache_memory() { + auto key_cache_dt = data_types::f16; + auto adjusted_head_size = k_head_size; + if (kv_cache_compression) { + key_cache_dt = data_types::i8; + adjusted_head_size += 4; + } + + auto num_blocks = block_indices.back() + 1; + auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; + auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; + auto memory = test_engine.allocate_memory(key_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len + 1, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + + input_token_offset * num_heads * v_head_size + + head_idx * v_head_size; + if (kv_cache_compression) { + auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); + auto quantized_data_ptr = quantized_data.data(); + + // shape: [num_blocks, num_heads, block_size, adjusted_head_size] + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + + token_idx * v_head_size; + set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); + + size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } else { + // shape: [num_blocks, num_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + + head_idx * block_size * v_head_size + + token_idx * v_head_size; + + set_values(test_stream, memory, data_ptr, v_head_size, output_offset); + } + } + } + } + } + } + + return memory; + } + +#else + memory::ptr get_key_cache_memory() { + auto key_cache_dt = data_types::f16; + auto adjusted_head_size = k_head_size; + auto adjusted_block_size = block_size; + if (kv_cache_compression) { + key_cache_dt = data_types::i8; + if (key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { + adjusted_block_size += 4; + } else { + adjusted_head_size += 4; + } + } + + auto num_blocks = block_indices.back() + 1; + auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, adjusted_head_size, adjusted_block_size }; + auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; + auto memory = test_engine.allocate_memory(key_cache_layout); + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len + 1, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + // quantize by channel + if (kv_cache_compression && key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { + std::vector token_block(block_size); + for (int token_idx = 0; token_idx < last_token_idx; ++token_idx) { + size_t input_token_offset = block_idx * block_size + token_idx; + token_block[token_idx] = *(key_data[i].data() + input_token_offset * num_heads * k_head_size + head_idx * k_head_size + k_head_size_idx); + } + auto [quantized_data, scale, zp] = quantize_data(token_block.data(), last_token_idx, true); + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * adjusted_head_size * adjusted_block_size + + head_idx * adjusted_head_size * adjusted_block_size; + size_t output_offset = output_block_offset + + k_head_size_idx * adjusted_block_size; + set_values(test_stream, memory, quantized_data.data(), last_token_idx, output_offset); + size_t comp_offset = (output_offset + block_size)/2; + set_values(test_stream, memory, &scale, 1, comp_offset); + set_values(test_stream, memory, &zp, 1, comp_offset + 1); + } + } + } + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + if (kv_cache_compression) { + if (key_cache_quant_mode == ov::internal::CacheQuantMode::BY_TOKEN) { + // quantize by token + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + + input_token_offset * num_heads * k_head_size + + head_idx * k_head_size; + // shape: [num_blocks, num_heads, adjusted_head_size, block_size] + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * adjusted_head_size * block_size + + head_idx * adjusted_head_size * block_size; + + auto [quantized_data, scale, zp] = quantize_data(data_ptr, k_head_size); + for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { + auto quantized_data_ptr = quantized_data.data() + k_head_size_idx; + + size_t output_offset = output_block_offset + + k_head_size_idx * block_size + + token_idx; + + set_values(test_stream, memory, quantized_data_ptr, 1, output_offset); + } + size_t comp_offset = (output_block_offset + k_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } + } else { + for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + + input_token_offset * num_heads * k_head_size + + head_idx * k_head_size + k_head_size_idx; + + // shape: [num_blocks, num_heads, k_head_size, block_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * k_head_size * block_size + + head_idx * k_head_size * block_size + + k_head_size_idx * block_size + + token_idx; + + set_values(test_stream, memory, data_ptr, 1, output_offset); + } + } + } + } + } + } + } + + return memory; + } +#endif + + memory::ptr get_value_cache_memory() { + auto value_cache_dt = data_types::f16; + auto adjusted_head_size = v_head_size; + if (kv_cache_compression) { + value_cache_dt = data_types::i8; + adjusted_head_size += 4; + } + + auto num_blocks = block_indices.back() + 1; + auto value_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; + auto value_cache_layout = layout{ value_cache_shape, value_cache_dt, format::bfyx }; + auto memory = test_engine.allocate_memory(value_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len + 1, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = value_data[i].data() + + input_token_offset * num_heads * v_head_size + + head_idx * v_head_size; + if (kv_cache_compression) { + auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); + auto quantized_data_ptr = quantized_data.data(); + + // shape: [num_blocks, num_heads, block_size, adjusted_head_size] + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + + token_idx * v_head_size; + set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); + + size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } else { + // shape: [num_blocks, num_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + + head_idx * block_size * v_head_size + + token_idx * v_head_size; + + set_values(test_stream, memory, data_ptr, v_head_size, output_offset); + } + } + } + } + } + } + + return memory; + } + + memory::ptr get_past_lens_memory() { + return get_memory_from_vec(past_lens); + } + + memory::ptr get_subsequence_begins_memory() { + return get_memory_from_vec(subsequence_begins); + } + + memory::ptr get_block_indices_memory() { + return get_memory_from_vec(block_indices); + } + + memory::ptr get_block_indices_begins_memory() { + return get_memory_from_vec(block_indices_begins); + } + + memory::ptr get_scale_memory() { + std::vector scale = { ov::float16(get_default_scale()) }; + return get_memory_from_vec(scale); + } + + memory::ptr get_sliding_window_memory() { + std::vector sliding_window = { 0 }; + return get_memory_from_vec(sliding_window); + } + + memory::ptr get_alibi_memory() { + std::vector alibi; + return get_memory_from_vec(alibi); + } + + memory::ptr get_max_context_len_memory() { + return get_memory_from_vec(max_context_len); + } + + memory::ptr get_score_aggregation() { + return get_memory_from_vec(score_aggregation); + } + + memory::ptr get_rotated_block_indices_memory() { + return get_memory_from_vec(rotated_block_indices); + } + + memory::ptr get_rotation_deltas_memory() { + auto mem = get_memory_from_vec(rotation_deltas); + auto layout = mem->get_layout(); + auto last_dim = rotation_config.per_block ? 1 : block_size; + layout.set_partial_shape(ov::PartialShape{ static_cast(rotated_block_indices.size()), last_dim }); + + return test_engine.reinterpret_buffer(*mem, layout); + } + + memory::ptr get_rotation_trig_lut_memory() { + auto mem = get_memory_from_vec(rotation_trig_lut); + auto layout = mem->get_layout(); + layout.set_partial_shape(ov::PartialShape{ max_context_len[0], k_head_size }); + + if (rotated_block_indices.empty()) { + auto empty_layout = mem->get_layout(); + empty_layout.set_partial_shape(ov::PartialShape{ 0, k_head_size }); + return test_engine.reinterpret_buffer(*mem, empty_layout); + } + + return test_engine.reinterpret_buffer(*mem, layout); + } + + memory::ptr get_xattention_threshold_memory() { + auto mem = get_memory_from_vec(xattention_threshold); + auto layout = mem->get_layout(); + layout.set_partial_shape(ov::PartialShape{ 1 }); + + if (xattention_threshold.empty()) { + auto empty_layout = mem->get_layout(); + empty_layout.set_partial_shape(ov::PartialShape{ 0 }); + return test_engine.reinterpret_buffer(*mem, empty_layout); + } + + return test_engine.reinterpret_buffer(*mem, layout); + } + + memory::ptr get_xattention_block_size_memory() { + return get_memory_from_vec(xattention_block_size); + } + + memory::ptr get_xattention_stride_memory() { + return get_memory_from_vec(xattention_stride); + } + + float get_default_scale() { + return static_cast(1.f / std::sqrt(k_head_size)); + } + +private: + template + memory::ptr get_memory_from_vec(std::vector& input_data) { + auto data_size = input_data.empty() ? 1 : input_data.size(); + auto shape = ov::PartialShape{ static_cast(data_size) }; + auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + auto memory = test_engine.allocate_memory(layout); + + if (input_data.empty()) { + auto shape = ov::PartialShape{0}; + auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + return test_engine.reinterpret_buffer(*memory, layout); + } + + set_values(test_stream, memory, input_data.data(), input_data.size(), 0); + + return memory; + } + + memory::ptr get_QKV_memory(std::vector>& input_data, int k_head_size, bool skip_past_len) { + int total_tokens = 0; + for (const auto& subsequence_desc : subsequence_descs) + total_tokens += subsequence_desc.num_tokens; + + auto query_shape = ov::PartialShape{ total_tokens, num_heads * k_head_size }; + auto query_layout = layout{ query_shape, data_types::f16, format::bfyx }; + auto memory = test_engine.allocate_memory(query_layout); + + for (int subsequence_idx = 0; subsequence_idx < static_cast(subsequence_descs.size()); subsequence_idx++) { + for (int token_idx = 0; token_idx < subsequence_descs[subsequence_idx].num_tokens; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = token_idx; + // as generated data stored in vectors includes past_len, ignore it for KV inputs + if (skip_past_len) + input_token_offset += subsequence_descs[subsequence_idx].past_len; + + ov::float16* data_ptr = input_data[subsequence_idx].data() + + input_token_offset * num_heads * k_head_size + + head_idx * k_head_size; + + size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx; + size_t output_offset = output_token_offset * num_heads * k_head_size + + head_idx * k_head_size; + + set_values(test_stream, memory, data_ptr, k_head_size, output_offset); + } + } + } + + return memory; + } + + template + static void set_values(stream& stream, memory::ptr mem, T* vals, size_t size, size_t dst_offset) { + mem_lock mem_ptr(mem, stream); + for (size_t i = 0; i < size; i++) { + mem_ptr[dst_offset + i] = vals[i]; + } + } + + static std::vector generate_input_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t k_head_size) { + const size_t total_elements_num = tokens_num * num_heads * k_head_size; + auto data = rg.generate_random_1d(total_elements_num, -1, 1); + + // test code + // auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 10000); + + return data; + } + + static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { + const size_t total_elements_num = per_block ? rotated_blocks_num + : rotated_blocks_num * block_size; + auto data = rg.generate_random_1d(total_elements_num, 0, static_cast(max_tokens_num - 1)); + + return data; + } + + static std::vector generate_rotation_trig_lut_data(tests::random_generator& rg, size_t max_tokens_num, size_t k_head_size) { + const size_t total_elements_num = max_tokens_num * k_head_size; + auto data = rg.generate_random_1d(total_elements_num, -1, 1); + + return data; + } + + static std::tuple, ov::float16, ov::float16> quantize_data(ov::float16* data, size_t size, bool expand_range = false) { + float min_value = std::numeric_limits::max(); + float max_value = std::numeric_limits::lowest(); + + for (size_t i = 0; i < size; i++) { + min_value = std::min((float)(data[i]), min_value); + max_value = std::max((float)(data[i]), max_value); + } + + float diff_value = 0.001; + if (max_value != min_value) + diff_value = max_value - min_value; + if (expand_range && std::abs(diff_value) <= std::abs(max_value) * 0.1f) { + // compensate too small range + diff_value = (max_value - min_value) + std::max(1.0f, max_value * 0.1f); + } + float scale = (std::numeric_limits::max() - std::numeric_limits::lowest()) / diff_value; + float zp = ((float)-min_value * scale) + std::numeric_limits::lowest(); + + std::vector quantized_data; + quantized_data.resize(size); + + auto convert_char_rte = [](float val) { + float rounded = std::nearbyint(val); + + if (rounded > 127.0f) { + return static_cast(127); + } else if (rounded < -128.0f) { + return static_cast(-128); + } else { + return static_cast(rounded); + } + }; + + for (size_t i = 0; i < size; i++) { + quantized_data[i] = convert_char_rte(data[i] * scale + zp); + } + + scale = 1.0f / scale; + + return std::make_tuple(quantized_data, scale, zp); + } +}; + namespace std { template <> struct hash { @@ -737,4 +1373,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing: paged_attention_test_params{ {{5, 10}}, 2, 64, 64, 16, 2, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 34}, {2, 20}, {10, 34}}, 2, 64, 64, 16, 10, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token #endif -})); +})); \ No newline at end of file diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index a6be28bad3e47d..a839bec479c9c7 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -17,7 +17,6 @@ #include "openvino/reference/transpose.hpp" #include "openvino/runtime/tensor.hpp" -#include "paged_attention_gpu_test.hpp" #include "random_generator.hpp" #include "test_utils.h" @@ -25,6 +24,586 @@ using namespace cldnn; using namespace ov::intel_gpu; using namespace ::tests; +/* +* PagedAttention inputs: +* [0]: query +* shape: [batch_size_in_tokens, num_heads * head_size], type: f16 +* [1]: key +* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 +* [2]: value  +* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 +* [3]: key_cache +* shape: [num_blocks, num_kv_heads, head_size, block_size], type: f16 +* [4]: value_cache +* shape: [num_blocks, num_kv_heads, block_size, head_size], type: f16 +* [5]: past_lens +* shape: [batch_size_in_sequences], type: i32 +* [6]: subsequence_begins +* shape: [batch_size_in_sequences + 1], type: i32 +* [7]: block_indices +* Shape: [num_blocks], type: i32 +* [8]: block_indices_begins +* Shape: [batch_size_in_sequences + 1], type: i32 +* [9]: scale, optional +* [10]: sliding_window, optional +* [11]: alibi_slopes, optional +* [12]: max_context_len +* shape: [], type: i32 +* [13]: score_aggregation_window​, optional​, shape: [batch_size_in_sequences] +* [14]: rotated_block_indices​, optional​ +* shape: [num_rotated_blocks]​, type: i32 +* [15]: rotation_deltas​, optional​ +* shape: [num_rotated_blocks, BLOCK_SIZE]​ || [num_rotated_blocks, 1]​, type: i32 +* [16]: rotation_trig_lut​, optional​ +* shape: [max_num_batched_tokens / BLOCK_SIZE, head_size]​ || [max_num_batched_tokens, head_size], type: f16 +*/ + + +enum class ScoresMode { + DISABLED = 0, + LAST_TOKEN, + SNAPKV +}; + +struct SubsequenceDescriptor { + int num_tokens; + int past_len; +}; + +struct CacheRotationDescriptor { + bool apply_rotation; + // configures 2nd dimension of rotation_deltas + // if per_block is true, single value is used for all tokens inside the block + // otherwise, each token uses an independent value + bool per_block; +}; + +struct PagedAttentionManager { + int num_heads; + int k_head_size; + int v_head_size; + int block_size; + int sliding_window_size; + bool kv_cache_compression; + ov::internal::CacheQuantMode key_cache_quant_mode; + bool has_score_aggregation; + CacheRotationDescriptor rotation_config; + std::vector subsequence_descs; + + // per-subsequence QKV inputs + std::vector> query_data; // {[1, num_tokens, num_heads, k_head_size], ..} + std::vector> key_data; // {[1, past_len + num_tokens, num_heads, k_head_size], ..} + std::vector> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..} + + // common PA inputs + std::vector past_lens; + std::vector subsequence_begins; + std::vector block_indices; + std::vector block_indices_begins; + std::vector max_context_len; + std::vector score_aggregation_window; + + // score aggregation related inputs + std::vector score_aggregation; + + // rotation related inputs + std::vector rotated_block_indices; + std::vector rotation_deltas; + std::vector rotation_trig_lut; + + std::vector xattention_threshold; + std::vector xattention_block_size; + std::vector xattention_stride; + + cldnn::engine& test_engine; + cldnn::stream& test_stream; + tests::random_generator& rg; + + PagedAttentionManager(tests::random_generator& rg, + cldnn::engine& engine, + cldnn::stream& stream, + const std::vector& subsequence_descs, + int num_heads, + int k_head_size, + int v_head_size, + int block_size, + int sliding_window_size, + bool kv_cache_compression, + ov::internal::CacheQuantMode key_cache_quant_mode, + bool has_score_aggregation, + CacheRotationDescriptor rotation_config) + : num_heads(num_heads) + , k_head_size(k_head_size) + , v_head_size(v_head_size) + , block_size(block_size) + , sliding_window_size(sliding_window_size) + , kv_cache_compression(kv_cache_compression) + , key_cache_quant_mode(key_cache_quant_mode) + , has_score_aggregation(has_score_aggregation) + , rotation_config(rotation_config) + , subsequence_descs(subsequence_descs) + , test_engine(engine) + , test_stream(stream) + , rg(rg) { + // init subsequence_begins and block_indices_begins + subsequence_begins.push_back(0); + block_indices_begins.push_back(0); + + int max_len = 0; + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + const auto& subsequence_desc = subsequence_descs[i]; + max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); + + query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); + key_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); + value_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); + + past_lens.push_back(subsequence_desc.past_len); + int subsequence_start_pos = subsequence_begins[i]; + int subsequence_end_pos = subsequence_start_pos + subsequence_desc.num_tokens; + subsequence_begins.push_back(subsequence_end_pos); + + int subsequence_length = subsequence_desc.num_tokens + subsequence_desc.past_len; + int required_blocks = ceil_div(subsequence_length, block_size); + int start_block_idx = block_indices.empty() ? 0 : block_indices.back() + 1; + int end_block_idx = start_block_idx + required_blocks; + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + block_indices.push_back(block_idx); + } + + int block_indices_start_pos = block_indices_begins[i]; + int block_indices_end_pos = block_indices_start_pos + required_blocks; + block_indices_begins.push_back(block_indices_end_pos); + } + max_context_len.push_back(max_len); + + if (rotation_config.apply_rotation) { + // iterate over KV-cache blocks and apply cache rotation to every second + // fully occupied block + for (size_t i = 0; i < subsequence_descs.size(); i++) { + const auto& subsequence_desc = subsequence_descs[i]; + int past_len = subsequence_desc.past_len; + int start_block_idx = block_indices_begins[i]; + for (int block_idx = 1; block_idx < past_len / block_size; block_idx++) { + if (block_idx % 2 != 0) { + rotated_block_indices.push_back(start_block_idx + block_idx); + } + } + } + + if (!rotated_block_indices.empty()) { + rotation_deltas = generate_rotation_deltas_data(rg, + max_context_len[0], + rotated_block_indices.size(), + block_size, + rotation_config.per_block); + rotation_trig_lut = generate_rotation_trig_lut_data(rg, max_context_len[0], k_head_size); + } + } + + if (has_score_aggregation) { + for (const auto& subsequence_desc : subsequence_descs) { + const auto max_tokens = 10; + auto max_window_size = std::min(subsequence_desc.num_tokens, max_tokens); + auto window_size = rg.generate_random_val(1, max_window_size); + score_aggregation.push_back(window_size); + } + } + } + + memory::ptr get_query_memory() { + return get_QKV_memory(query_data, k_head_size, false); + } + + memory::ptr get_key_memory() { + return get_QKV_memory(key_data, k_head_size, true); + } + + memory::ptr get_value_memory() { + return get_QKV_memory(value_data, v_head_size, true); + } + + memory::ptr get_key_cache_memory() { + auto key_cache_dt = data_types::f16; + auto adjusted_head_size = k_head_size; + if (kv_cache_compression) { + key_cache_dt = data_types::i8; + adjusted_head_size += 4; + } + + auto num_blocks = block_indices.back() + 1; + auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; + auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; + auto memory = test_engine.allocate_memory(key_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len + 1, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + + input_token_offset * num_heads * v_head_size + + head_idx * v_head_size; + if (kv_cache_compression) { + auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); + auto quantized_data_ptr = quantized_data.data(); + + // shape: [num_blocks, num_heads, block_size, adjusted_head_size] + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + + token_idx * v_head_size; + set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); + + size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } else { + // shape: [num_blocks, num_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + + head_idx * block_size * v_head_size + + token_idx * v_head_size; + + set_values(test_stream, memory, data_ptr, v_head_size, output_offset); + } + } + } + } + } + } + + return memory; + } + + memory::ptr get_value_cache_memory() { + auto value_cache_dt = data_types::f16; + auto adjusted_head_size = v_head_size; + if (kv_cache_compression) { + value_cache_dt = data_types::i8; + adjusted_head_size += 4; + } + + auto num_blocks = block_indices.back() + 1; + auto value_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; + auto value_cache_layout = layout{ value_cache_shape, value_cache_dt, format::bfyx }; + auto memory = test_engine.allocate_memory(value_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len + 1, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = value_data[i].data() + + input_token_offset * num_heads * v_head_size + + head_idx * v_head_size; + if (kv_cache_compression) { + auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); + auto quantized_data_ptr = quantized_data.data(); + + // shape: [num_blocks, num_heads, block_size, adjusted_head_size] + size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + + token_idx * v_head_size; + set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); + + size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } else { + // shape: [num_blocks, num_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + + head_idx * block_size * v_head_size + + token_idx * v_head_size; + + set_values(test_stream, memory, data_ptr, v_head_size, output_offset); + } + } + } + } + } + } + + return memory; + } + + memory::ptr get_past_lens_memory() { + return get_memory_from_vec(past_lens); + } + + memory::ptr get_subsequence_begins_memory() { + return get_memory_from_vec(subsequence_begins); + } + + memory::ptr get_block_indices_memory() { + return get_memory_from_vec(block_indices); + } + + memory::ptr get_block_indices_begins_memory() { + return get_memory_from_vec(block_indices_begins); + } + + memory::ptr get_scale_memory() { + std::vector scale = { ov::float16(get_default_scale()) }; + return get_memory_from_vec(scale); + } + + memory::ptr get_sliding_window_memory() { + std::vector sliding_window = { 0 }; + return get_memory_from_vec(sliding_window); + } + + memory::ptr get_alibi_memory() { + std::vector alibi; + return get_memory_from_vec(alibi); + } + + memory::ptr get_max_context_len_memory() { + return get_memory_from_vec(max_context_len); + } + + memory::ptr get_score_aggregation() { + return get_memory_from_vec(score_aggregation); + } + + memory::ptr get_rotated_block_indices_memory() { + return get_memory_from_vec(rotated_block_indices); + } + + memory::ptr get_rotation_deltas_memory() { + auto mem = get_memory_from_vec(rotation_deltas); + auto layout = mem->get_layout(); + auto last_dim = rotation_config.per_block ? 1 : block_size; + layout.set_partial_shape(ov::PartialShape{ static_cast(rotated_block_indices.size()), last_dim }); + + return test_engine.reinterpret_buffer(*mem, layout); + } + + memory::ptr get_rotation_trig_lut_memory() { + auto mem = get_memory_from_vec(rotation_trig_lut); + auto layout = mem->get_layout(); + layout.set_partial_shape(ov::PartialShape{ max_context_len[0], k_head_size }); + + if (rotated_block_indices.empty()) { + auto empty_layout = mem->get_layout(); + empty_layout.set_partial_shape(ov::PartialShape{ 0, k_head_size }); + return test_engine.reinterpret_buffer(*mem, empty_layout); + } + + return test_engine.reinterpret_buffer(*mem, layout); + } + + memory::ptr get_xattention_threshold_memory() { + auto mem = get_memory_from_vec(xattention_threshold); + auto layout = mem->get_layout(); + layout.set_partial_shape(ov::PartialShape{ 1 }); + + if (xattention_threshold.empty()) { + auto empty_layout = mem->get_layout(); + empty_layout.set_partial_shape(ov::PartialShape{ 0 }); + return test_engine.reinterpret_buffer(*mem, empty_layout); + } + + return test_engine.reinterpret_buffer(*mem, layout); + } + + memory::ptr get_xattention_block_size_memory() { + return get_memory_from_vec(xattention_block_size); + } + + memory::ptr get_xattention_stride_memory() { + return get_memory_from_vec(xattention_stride); + } + + float get_default_scale() { + return static_cast(1.f / std::sqrt(k_head_size)); + } + +private: + template + memory::ptr get_memory_from_vec(std::vector& input_data) { + auto data_size = input_data.empty() ? 1 : input_data.size(); + auto shape = ov::PartialShape{ static_cast(data_size) }; + auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + auto memory = test_engine.allocate_memory(layout); + + if (input_data.empty()) { + auto shape = ov::PartialShape{0}; + auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + return test_engine.reinterpret_buffer(*memory, layout); + } + + set_values(test_stream, memory, input_data.data(), input_data.size(), 0); + + return memory; + } + + memory::ptr get_QKV_memory(std::vector>& input_data, int k_head_size, bool skip_past_len) { + int total_tokens = 0; + for (const auto& subsequence_desc : subsequence_descs) + total_tokens += subsequence_desc.num_tokens; + + auto query_shape = ov::PartialShape{ total_tokens, num_heads * k_head_size }; + auto query_layout = layout{ query_shape, data_types::f16, format::bfyx }; + auto memory = test_engine.allocate_memory(query_layout); + + for (int subsequence_idx = 0; subsequence_idx < static_cast(subsequence_descs.size()); subsequence_idx++) { + for (int token_idx = 0; token_idx < subsequence_descs[subsequence_idx].num_tokens; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = token_idx; + // as generated data stored in vectors includes past_len, ignore it for KV inputs + if (skip_past_len) + input_token_offset += subsequence_descs[subsequence_idx].past_len; + + ov::float16* data_ptr = input_data[subsequence_idx].data() + + input_token_offset * num_heads * k_head_size + + head_idx * k_head_size; + + size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx; + size_t output_offset = output_token_offset * num_heads * k_head_size + + head_idx * k_head_size; + + set_values(test_stream, memory, data_ptr, k_head_size, output_offset); + } + } + } + + return memory; + } + + template + static void set_values(stream& stream, memory::ptr mem, T* vals, size_t size, size_t dst_offset) { + mem_lock mem_ptr(mem, stream); + for (size_t i = 0; i < size; i++) { + mem_ptr[dst_offset + i] = vals[i]; + } + } + + static std::vector generate_input_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t k_head_size) { + const size_t total_elements_num = tokens_num * num_heads * k_head_size; + auto data = rg.generate_random_1d(total_elements_num, -1, 1); + + // test code + // auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 10000); + + return data; + } + +static std::vector generate_input_data_ww( + tests::random_generator& rg, + size_t num_heads, + size_t tokens_num, + size_t k_head_size, + float stddev = 0.5f, // 控制数据分布集中程度 + bool normalize = true // 是否对每个向量做归一化 +) { + const size_t total_elements_num = tokens_num * num_heads * k_head_size; + auto data = rg.generate_random_1d(total_elements_num, -1, 1); + + // 将均匀分布映射到近似正态分布 + for (size_t i = 0; i < total_elements_num; ++i) { + float x = static_cast(data[i]); + // Box-Muller transform for simple Gaussian-like distribution + float u1 = (x + 1.f) / 2.f; // [0,1] + float u2 = rg.generate_random_1d(1, 0.f, 1.f)[0]; // 另一个随机数 + float r = std::sqrt(-2.f * std::log(u1 + 1e-6f)) * stddev; // 避免 log(0) + float theta = 2.f * 3.1415926535f * u2; + float val = r * std::cos(theta); + data[i] = ov::float16(val); + } + + if (normalize) { + // 对每个 head 的每个 token 做 L2 归一化 + for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { + for (size_t token_idx = 0; token_idx < tokens_num; ++token_idx) { + float norm = 0.f; + for (size_t dim = 0; dim < k_head_size; ++dim) { + float val = static_cast(data[head_idx * tokens_num * k_head_size + token_idx * k_head_size + dim]); + norm += val * val; + } + norm = std::sqrt(norm) + 1e-6f; + for (size_t dim = 0; dim < k_head_size; ++dim) { + size_t idx = head_idx * tokens_num * k_head_size + token_idx * k_head_size + dim; + data[idx] = ov::float16(static_cast(data[idx]) / norm); + } + } + } + } + + return data; +} + + static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { + const size_t total_elements_num = per_block ? rotated_blocks_num + : rotated_blocks_num * block_size; + auto data = rg.generate_random_1d(total_elements_num, 0, static_cast(max_tokens_num - 1)); + + return data; + } + + static std::vector generate_rotation_trig_lut_data(tests::random_generator& rg, size_t max_tokens_num, size_t k_head_size) { + const size_t total_elements_num = max_tokens_num * k_head_size; + auto data = rg.generate_random_1d(total_elements_num, -1, 1); + + return data; + } + + static std::tuple, ov::float16, ov::float16> quantize_data(ov::float16* data, size_t size, bool expand_range = false) { + float min_value = std::numeric_limits::max(); + float max_value = std::numeric_limits::lowest(); + + for (size_t i = 0; i < size; i++) { + min_value = std::min((float)(data[i]), min_value); + max_value = std::max((float)(data[i]), max_value); + } + + float diff_value = 0.001; + if (max_value != min_value) + diff_value = max_value - min_value; + if (expand_range && std::abs(diff_value) <= std::abs(max_value) * 0.1f) { + // compensate too small range + diff_value = (max_value - min_value) + std::max(1.0f, max_value * 0.1f); + } + float scale = (std::numeric_limits::max() - std::numeric_limits::lowest()) / diff_value; + float zp = ((float)-min_value * scale) + std::numeric_limits::lowest(); + + std::vector quantized_data; + quantized_data.resize(size); + + auto convert_char_rte = [](float val) { + float rounded = std::nearbyint(val); + + if (rounded > 127.0f) { + return static_cast(127); + } else if (rounded < -128.0f) { + return static_cast(-128); + } else { + return static_cast(rounded); + } + }; + + for (size_t i = 0; i < size; i++) { + quantized_data[i] = convert_char_rte(data[i] * scale + zp); + } + + scale = 1.0f / scale; + + return std::make_tuple(quantized_data, scale, zp); + } +}; + using Shape = std::vector; using CMXAttentionBlockIndex = @@ -39,50 +618,38 @@ class CMXAttentionBlockSelector { OPENVINO_ASSERT(m_block_size % m_stride == 0); } - void diagonal_reshape(const T* input_data, const Shape& input_shape, T* output_data, const Shape& out_shape, bool is_antidiagonal) { + void diagonal_reshape(const T* input_data, + const Shape& input_shape, + T* output_data, + const Shape& output_shape, + bool is_antidiagonal) { OPENVINO_ASSERT(input_shape.size() == 3); - OPENVINO_ASSERT(out_shape.size() == 3); - OPENVINO_ASSERT(input_shape[0] == out_shape[0]); - OPENVINO_ASSERT(input_shape[1] % m_stride == 0); - OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]); - OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]); - - size_t num_stride_steps = input_shape[1] / m_stride; - for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) { - size_t head_offset = head_idx * input_shape[1] * input_shape[2]; - for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) { - for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) { - size_t input_offset = head_offset; - size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2]; - if (is_antidiagonal) { - input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2]; - } else { - input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2]; - } - std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T)); - } - } - } - } - - void diagonal_reshape_kdb1_no_batch(const T* input_data, - const std::vector& input_shape, // [H, Q_orig, dim] - T* output_data, - const std::vector& output_shape) { + OPENVINO_ASSERT(output_shape.size() == 3); size_t H = input_shape[0]; size_t Q_orig = input_shape[1]; - size_t dim = input_shape[2]; + size_t D = input_shape[2]; size_t Q_new = output_shape[1]; + OPENVINO_ASSERT(Q_orig % m_stride == 0); + OPENVINO_ASSERT(Q_orig / m_stride == Q_new); + for (size_t h = 0; h < H; ++h) { - size_t head_in_offset = h * Q_orig * dim; - size_t head_out_offset = h * Q_new * m_stride * dim; + size_t head_in_offset = h * Q_orig * D; + size_t head_out_offset = h * Q_new * m_stride * D; for (size_t s = 0; s < m_stride; ++s) { for (size_t q = 0; q < Q_new; ++q) { - size_t in_idx = head_in_offset + (m_stride - 1 - s + q * m_stride) * dim; - size_t out_idx = head_out_offset + q * m_stride * dim + s * dim; - std::memcpy(output_data + out_idx, input_data + in_idx, dim * sizeof(T)); + size_t in_idx; + if (is_antidiagonal) { + // Anti-diagonal: (stride - 1 - s + q * stride) + in_idx = head_in_offset + (m_stride - 1 - s + q * m_stride) * D; + } else { + // Normal diagonal: (s + q * stride) + in_idx = head_in_offset + (s + q * m_stride) * D; + } + + size_t out_idx = head_out_offset + q * m_stride * D + s * D; + std::memcpy(output_data + out_idx, input_data + in_idx, D * sizeof(T)); } } } @@ -145,8 +712,10 @@ class CMXAttentionBlockSelector { } } - CMXAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(T* blocked_attention_scores_data, const Shape& blocked_attention_scores_shape) { - OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3, "Expected shape [num_heads, q_block_num, k_block_num]"); + CMXAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(T* blocked_attention_scores_data, + const Shape& blocked_attention_scores_shape) { + OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3, + "Expected shape [num_heads, q_block_num, k_block_num]"); size_t num_heads = blocked_attention_scores_shape[0]; size_t q_block_num = blocked_attention_scores_shape[1]; @@ -154,7 +723,9 @@ class CMXAttentionBlockSelector { CMXAttentionRetainedBlockIndicesForAllHeads retval(num_heads); - std::vector>> mask(num_heads, std::vector>(q_block_num, std::vector(k_block_num, false))); + std::vector>> mask( + num_heads, + std::vector>(q_block_num, std::vector(k_block_num, false))); for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { @@ -162,7 +733,7 @@ class CMXAttentionBlockSelector { if (diagonal_k < k_block_num) { mask[head_idx][q_block_idx][diagonal_k] = true; } - // Step1: Keep the first column + // Step1: First column reserved mask[head_idx][q_block_idx][0] = true; // Step2: Create other_values(masked_fill) @@ -182,6 +753,7 @@ class CMXAttentionBlockSelector { // Step4: Create cumulative_sum_without_self,cat([0, diagonal_sum, sorted_values[:-1]]) std::vector sorted_scores; sorted_scores.push_back(0.0); + // diagonal + First column score size_t offset_diag = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + diagonal_k; float diag_score = static_cast(blocked_attention_scores_data[offset_diag]); float first_col_score = 0.0; @@ -220,7 +792,7 @@ class CMXAttentionBlockSelector { index_mask[i] = (cumsum_without_self[i] < required_sum); } - // Step8: Ceate index + // Step8: Create index std::vector index(index_mask.size(), 0); for (size_t i = 0; i < index_mask.size(); i++) { if (index_mask[i]) { @@ -235,12 +807,14 @@ class CMXAttentionBlockSelector { } } + // Step9: Get retval for (size_t i = 0; i < index.size(); i++) { size_t k_block_idx = index[i]; if (index_mask[i] && k_block_idx < k_block_num) { mask[head_idx][q_block_idx][k_block_idx] = true; } } + for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { if (mask[head_idx][q_block_idx][k_block_idx]) retval[head_idx].insert({q_block_idx, k_block_idx}); @@ -251,104 +825,105 @@ class CMXAttentionBlockSelector { return retval; } - CMXAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, const Shape& query_shape, const T* key_data, const Shape& key_shape) { - OPENVINO_ASSERT(query_shape.size() == 3); - OPENVINO_ASSERT(key_shape.size() == 3); - OPENVINO_ASSERT(key_shape[0] == query_shape[0]); - OPENVINO_ASSERT(key_shape[2] == query_shape[2]); - OPENVINO_ASSERT(query_shape[1] % m_stride == 0); - OPENVINO_ASSERT(key_shape[1] % m_stride == 0); - OPENVINO_ASSERT(query_shape[1] % m_block_size == 0); - OPENVINO_ASSERT(key_shape[1] % m_block_size == 0); - - size_t chunk_size = query_shape[1]; - size_t k_len = key_shape[1]; - size_t head_dim = query_shape[2]; - size_t num_heads = query_shape[0]; - size_t k_num_to_pad = ((k_len + chunk_size - 1) / chunk_size) * chunk_size - k_len; - Shape pad_key_shape = {num_heads, k_len + k_num_to_pad, head_dim}; - auto pad_key_buf = allocate_buf(pad_key_shape); - - for (size_t h = 0; h < num_heads; h++) - for (size_t t = 0; t < k_len; t++) - for (size_t d = 0; d < head_dim; d++) { - size_t offset = h * (k_len + k_num_to_pad) * head_dim + t * head_dim + d; - size_t original_offset = h * k_len * head_dim + t * head_dim + d; - pad_key_buf.get()[offset] = key_data[original_offset]; - } + CMXAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, + const Shape& query_shape, + const T* key_data, + const Shape& key_shape, + int chunk_size = -1) { + OPENVINO_ASSERT(query_shape.size() == 3 && key_shape.size() == 3); + OPENVINO_ASSERT(query_shape[0] == key_shape[0] && query_shape[2] == key_shape[2]); + OPENVINO_ASSERT(query_shape[1] % m_stride == 0 && key_shape[1] % m_stride == 0); + OPENVINO_ASSERT(query_shape[1] % m_block_size == 0 && key_shape[1] % m_block_size == 0); + + const size_t num_heads = query_shape[0]; + const size_t q_len = query_shape[1]; + const size_t k_len = key_shape[1]; + const size_t head_dim = query_shape[2]; + if (chunk_size == -1) chunk_size = q_len; + + auto pad_seq = [&](const T* src_data, size_t seq_len) { + size_t num_to_pad = ((seq_len + chunk_size - 1) / chunk_size) * chunk_size - seq_len; + Shape pad_shape = {num_heads, seq_len + num_to_pad, head_dim}; + auto buf = allocate_buf(pad_shape); + + for (size_t h = 0; h < num_heads; ++h) { + size_t src_off = h * seq_len * head_dim; + size_t dst_off = h * (seq_len + num_to_pad) * head_dim; + std::memcpy(buf.get() + dst_off, src_data + src_off, seq_len * head_dim * sizeof(T)); + if (num_to_pad) + std::fill(buf.get() + dst_off + seq_len * head_dim, + buf.get() + dst_off + (seq_len + num_to_pad) * head_dim, T(0)); + } + return std::make_pair(std::move(buf), pad_shape); + }; - size_t k_chunk_num = (k_len + k_num_to_pad) / chunk_size; - size_t offset_token_chunk_num = k_chunk_num - 1; - size_t reshaped_chunk_size = chunk_size / m_stride; - size_t k_reshaped_num_to_pad = k_num_to_pad / m_stride; - size_t k_reshaped_seq_len = (k_len + k_num_to_pad) / m_stride; - - Shape reshaped_query_shape = {num_heads, query_shape[1] / m_stride, head_dim * m_stride}; - auto q_buf = allocate_buf(reshaped_query_shape); - diagonal_reshape_kdb1_no_batch(query_data, query_shape, q_buf.get(), reshaped_query_shape); - Shape reshaped_key_shape = {num_heads, pad_key_shape[1] / m_stride, head_dim * m_stride}; - auto k_buf = allocate_buf(reshaped_key_shape); - diagonal_reshape(pad_key_buf.get(), pad_key_shape, k_buf.get(), reshaped_key_shape, false); - Shape transpose_matmul_scaled_shape = {num_heads, query_shape[1] / m_stride, pad_key_shape[1] / m_stride}; - auto qk_buf = allocate_buf(transpose_matmul_scaled_shape); - - transpose_matmul_scale(q_buf.get(), k_buf.get(), reshaped_query_shape, reshaped_key_shape, qk_buf.get(), transpose_matmul_scaled_shape); + // ======== Pad Query & Key ======== + auto [pad_query_buf, pad_query_shape] = pad_seq(query_data, q_len); + auto [pad_key_buf, pad_key_shape] = pad_seq(key_data, k_len); + + // ======== Diagonal Reshape ======== + const size_t reshaped_q_len = pad_query_shape[1] / m_stride; + const size_t reshaped_k_len = pad_key_shape[1] / m_stride; + Shape q_shape_r = {num_heads, reshaped_q_len, head_dim * m_stride}; + Shape k_shape_r = {num_heads, reshaped_k_len, head_dim * m_stride}; + + auto q_buf = allocate_buf(q_shape_r); + auto k_buf = allocate_buf(k_shape_r); + diagonal_reshape(pad_query_buf.get(), pad_query_shape, q_buf.get(), q_shape_r, true); + diagonal_reshape(pad_key_buf.get(), pad_key_shape, k_buf.get(), k_shape_r, false); + pad_query_buf.reset(); + pad_key_buf.reset(); + + // ======== QK^T + scale ======== + Shape qk_shape = {num_heads, reshaped_q_len, reshaped_k_len}; + auto qk_buf = allocate_buf(qk_shape); + transpose_matmul_scale(q_buf.get(), k_buf.get(), q_shape_r, k_shape_r, qk_buf.get(), qk_shape); q_buf.reset(); k_buf.reset(); - Shape causal_mask_shape = {num_heads, reshaped_chunk_size, reshaped_chunk_size * k_chunk_num}; - auto causal_mask_buf = allocate_buf(causal_mask_shape); - std::fill(causal_mask_buf.get(), causal_mask_buf.get() + ov::shape_size(causal_mask_shape), T(0)); - if (k_reshaped_num_to_pad) { - for (size_t h = 0; h < num_heads; h++) - for (size_t q = 0; q < reshaped_chunk_size; q++) - for (size_t k = k_reshaped_seq_len - k_reshaped_num_to_pad; k < k_reshaped_seq_len; k++) { - size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + q * (reshaped_chunk_size * k_chunk_num) + k; - - causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); - } - } - - size_t chunk_start = offset_token_chunk_num * reshaped_chunk_size; - size_t chunk_end = chunk_start + reshaped_chunk_size; - - for (size_t h = 0; h < num_heads; h++) { - for (size_t q = 0; q < reshaped_chunk_size; q++) { - for (size_t k = q + 1; k < reshaped_chunk_size; k++) { - size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + q * (reshaped_chunk_size * k_chunk_num) + chunk_start + k; - causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); - } + // ======== Causal Mask ======== + auto causal_mask_buf = allocate_buf(qk_shape); + std::fill(causal_mask_buf.get(), causal_mask_buf.get() + ov::shape_size(qk_shape), T(0)); + const size_t reshaped_chunk_size = q_len / m_stride; + const size_t k_chunk_num = (k_len + ((k_len + chunk_size - 1) / chunk_size * chunk_size - k_len)) / q_len; + const size_t k_reshaped_seq_len = pad_key_shape[1] / m_stride; + const size_t k_reshaped_num_to_pad = pad_key_shape[1] / m_stride - k_len / m_stride; + const size_t chunk_start = (k_chunk_num - 1) * reshaped_chunk_size; + const size_t chunk_end = chunk_start + reshaped_chunk_size; + const T neg_inf = std::numeric_limits::lowest(); + + for (size_t h = 0; h < num_heads; ++h) { + for (size_t q = 0; q < reshaped_chunk_size; ++q) { + size_t base = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + + q * (reshaped_chunk_size * k_chunk_num); + + for (size_t k = k_reshaped_seq_len - k_reshaped_num_to_pad; k < k_reshaped_seq_len; ++k) + causal_mask_buf.get()[base + k] = neg_inf; + for (size_t k = q + 1; k < reshaped_chunk_size; ++k) + causal_mask_buf.get()[base + chunk_start + k] = neg_inf; + for (size_t k = chunk_end; k < reshaped_chunk_size * k_chunk_num; ++k) + causal_mask_buf.get()[base + k] = neg_inf; } } - - for (size_t h = 0; h < num_heads; h++) { - for (size_t q = 0; q < reshaped_chunk_size; q++) { - for (size_t k = chunk_end; k < reshaped_chunk_size * k_chunk_num; k++) { - size_t offset = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + q * (reshaped_chunk_size * k_chunk_num) + k; - causal_mask_buf.get()[offset] = std::numeric_limits::lowest(); - } - } - } - size_t out_size = transpose_matmul_scaled_shape[0] * transpose_matmul_scaled_shape[1] * transpose_matmul_scaled_shape[2]; - - for (size_t i = 0; i < out_size; i++) { + // ======== qk += mask ======== + for (size_t i = 0; i < ov::shape_size(qk_shape); ++i) qk_buf.get()[i] += causal_mask_buf.get()[i]; - } - causal_mask_buf.reset(); - Shape attention_scores_shape = transpose_matmul_scaled_shape; - auto attn_score_buf = allocate_buf(attention_scores_shape); - softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape); - qk_buf.reset(); - size_t antidiagonals_per_xattention_block = m_block_size / m_stride; - Shape block_sum_shape = {attention_scores_shape[0], - attention_scores_shape[1] / antidiagonals_per_xattention_block, - attention_scores_shape[2] / antidiagonals_per_xattention_block}; + // ======== softmax ======== + auto attn_score_buf = allocate_buf(qk_shape); + softmax(qk_buf.get(), qk_shape, attn_score_buf.get(), qk_shape); + qk_buf.reset(); + // ======== block sum + select ======== + const size_t blocks_per_axis = m_block_size / m_stride; + Shape block_sum_shape = {num_heads, + reshaped_q_len / blocks_per_axis, + reshaped_k_len / blocks_per_axis}; auto block_sum_buf = allocate_buf(block_sum_shape); - block_sum_attention_scores(attn_score_buf.get(), attention_scores_shape, block_sum_buf.get(), block_sum_shape); + block_sum_attention_scores(attn_score_buf.get(), qk_shape, block_sum_buf.get(), block_sum_shape); attn_score_buf.reset(); + auto selected_block_indices = get_block_indices_to_keep(block_sum_buf.get(), block_sum_shape); block_sum_buf.reset(); @@ -784,11 +1359,7 @@ struct xAttentionTest : public ::testing::TestWithParam { query_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.k_head_size }); key_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.k_head_size }); value_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.v_head_size }); -#if ENABLE_PA_CM_PATH key_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.k_head_size }); -#else - key_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.k_head_size, p.block_size }); -#endif value_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.v_head_size }); past_lens_layout.set_partial_shape(ov::PartialShape{ -1 }); subsequence_begins_layout.set_partial_shape(ov::PartialShape{ -1 }); @@ -1012,8 +1583,6 @@ const auto DISABLE_FA_V2 = true; INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, xattention_test, ::testing::ValuesIn(std::vector{ - -#if ENABLE_PA_CM_PATH /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, token_size>=32, disable_mix_mode */ xattention_test_params{ {{32, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token @@ -1021,5 +1590,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, xattention_test_params{ {{1, 31}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, 32}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token -#endif })); diff --git a/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp b/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp deleted file mode 100644 index c5dd8e20e89b3e..00000000000000 --- a/src/plugins/intel_gpu/tests/unit/test_utils/paged_attention_gpu_test.hpp +++ /dev/null @@ -1,701 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "test_utils.h" -#include "random_generator.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace cldnn; -using namespace ov::intel_gpu; -using namespace ::tests; - -/* -* PagedAttention inputs: -* [0]: query -* shape: [batch_size_in_tokens, num_heads * head_size], type: f16 -* [1]: key -* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 -* [2]: value  -* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 -* [3]: key_cache -* shape: [num_blocks, num_kv_heads, head_size, block_size], type: f16 -* [4]: value_cache -* shape: [num_blocks, num_kv_heads, block_size, head_size], type: f16 -* [5]: past_lens -* shape: [batch_size_in_sequences], type: i32 -* [6]: subsequence_begins -* shape: [batch_size_in_sequences + 1], type: i32 -* [7]: block_indices -* Shape: [num_blocks], type: i32 -* [8]: block_indices_begins -* Shape: [batch_size_in_sequences + 1], type: i32 -* [9]: scale, optional -* [10]: sliding_window, optional -* [11]: alibi_slopes, optional -* [12]: max_context_len -* shape: [], type: i32 -* [13]: score_aggregation_window​, optional​, shape: [batch_size_in_sequences] -* [14]: rotated_block_indices​, optional​ -* shape: [num_rotated_blocks]​, type: i32 -* [15]: rotation_deltas​, optional​ -* shape: [num_rotated_blocks, BLOCK_SIZE]​ || [num_rotated_blocks, 1]​, type: i32 -* [16]: rotation_trig_lut​, optional​ -* shape: [max_num_batched_tokens / BLOCK_SIZE, head_size]​ || [max_num_batched_tokens, head_size], type: f16 -*/ - - -enum class ScoresMode { - DISABLED = 0, - LAST_TOKEN, - SNAPKV -}; - -struct SubsequenceDescriptor { - int num_tokens; - int past_len; -}; - -struct CacheRotationDescriptor { - bool apply_rotation; - // configures 2nd dimension of rotation_deltas - // if per_block is true, single value is used for all tokens inside the block - // otherwise, each token uses an independent value - bool per_block; -}; - -struct PagedAttentionManager { - int num_heads; - int k_head_size; - int v_head_size; - int block_size; - int sliding_window_size; - bool kv_cache_compression; - ov::internal::CacheQuantMode key_cache_quant_mode; - bool has_score_aggregation; - CacheRotationDescriptor rotation_config; - std::vector subsequence_descs; - - // per-subsequence QKV inputs - std::vector> query_data; // {[1, num_tokens, num_heads, k_head_size], ..} - std::vector> key_data; // {[1, past_len + num_tokens, num_heads, k_head_size], ..} - std::vector> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..} - - // common PA inputs - std::vector past_lens; - std::vector subsequence_begins; - std::vector block_indices; - std::vector block_indices_begins; - std::vector max_context_len; - std::vector score_aggregation_window; - - // score aggregation related inputs - std::vector score_aggregation; - - // rotation related inputs - std::vector rotated_block_indices; - std::vector rotation_deltas; - std::vector rotation_trig_lut; - - std::vector xattention_threshold; - std::vector xattention_block_size; - std::vector xattention_stride; - - cldnn::engine& test_engine; - cldnn::stream& test_stream; - tests::random_generator& rg; - - PagedAttentionManager(tests::random_generator& rg, - cldnn::engine& engine, - cldnn::stream& stream, - const std::vector& subsequence_descs, - int num_heads, - int k_head_size, - int v_head_size, - int block_size, - int sliding_window_size, - bool kv_cache_compression, - ov::internal::CacheQuantMode key_cache_quant_mode, - bool has_score_aggregation, - CacheRotationDescriptor rotation_config) - : num_heads(num_heads) - , k_head_size(k_head_size) - , v_head_size(v_head_size) - , block_size(block_size) - , sliding_window_size(sliding_window_size) - , kv_cache_compression(kv_cache_compression) - , key_cache_quant_mode(key_cache_quant_mode) - , has_score_aggregation(has_score_aggregation) - , rotation_config(rotation_config) - , subsequence_descs(subsequence_descs) - , test_engine(engine) - , test_stream(stream) - , rg(rg) { - // init subsequence_begins and block_indices_begins - subsequence_begins.push_back(0); - block_indices_begins.push_back(0); - - int max_len = 0; - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - const auto& subsequence_desc = subsequence_descs[i]; - max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); - - query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); - key_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); - value_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); - - past_lens.push_back(subsequence_desc.past_len); - int subsequence_start_pos = subsequence_begins[i]; - int subsequence_end_pos = subsequence_start_pos + subsequence_desc.num_tokens; - subsequence_begins.push_back(subsequence_end_pos); - - int subsequence_length = subsequence_desc.num_tokens + subsequence_desc.past_len; - int required_blocks = ceil_div(subsequence_length, block_size); - int start_block_idx = block_indices.empty() ? 0 : block_indices.back() + 1; - int end_block_idx = start_block_idx + required_blocks; - for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { - block_indices.push_back(block_idx); - } - - int block_indices_start_pos = block_indices_begins[i]; - int block_indices_end_pos = block_indices_start_pos + required_blocks; - block_indices_begins.push_back(block_indices_end_pos); - } - max_context_len.push_back(max_len); - - if (rotation_config.apply_rotation) { - // iterate over KV-cache blocks and apply cache rotation to every second - // fully occupied block - for (size_t i = 0; i < subsequence_descs.size(); i++) { - const auto& subsequence_desc = subsequence_descs[i]; - int past_len = subsequence_desc.past_len; - int start_block_idx = block_indices_begins[i]; - for (int block_idx = 1; block_idx < past_len / block_size; block_idx++) { - if (block_idx % 2 != 0) { - rotated_block_indices.push_back(start_block_idx + block_idx); - } - } - } - - if (!rotated_block_indices.empty()) { - rotation_deltas = generate_rotation_deltas_data(rg, - max_context_len[0], - rotated_block_indices.size(), - block_size, - rotation_config.per_block); - rotation_trig_lut = generate_rotation_trig_lut_data(rg, max_context_len[0], k_head_size); - } - } - - if (has_score_aggregation) { - for (const auto& subsequence_desc : subsequence_descs) { - const auto max_tokens = 10; - auto max_window_size = std::min(subsequence_desc.num_tokens, max_tokens); - auto window_size = rg.generate_random_val(1, max_window_size); - score_aggregation.push_back(window_size); - } - } - } - - memory::ptr get_query_memory() { - return get_QKV_memory(query_data, k_head_size, false); - } - - memory::ptr get_key_memory() { - return get_QKV_memory(key_data, k_head_size, true); - } - - memory::ptr get_value_memory() { - return get_QKV_memory(value_data, v_head_size, true); - } - -#if ENABLE_PA_CM_PATH - memory::ptr get_key_cache_memory() { - auto key_cache_dt = data_types::f16; - auto adjusted_head_size = k_head_size; - if (kv_cache_compression) { - key_cache_dt = data_types::i8; - adjusted_head_size += 4; - } - - auto num_blocks = block_indices.back() + 1; - auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; - auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; - auto memory = test_engine.allocate_memory(key_cache_layout); - - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - int past_len = subsequence_descs[i].past_len; - if (past_len != 0) { - int blocks_num = ceil_div(past_len + 1, block_size); - int start_block_idx = block_indices[block_indices_begins[i]]; - for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size - : block_size; - for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + - input_token_offset * num_heads * v_head_size + - head_idx * v_head_size; - if (kv_cache_compression) { - auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); - auto quantized_data_ptr = quantized_data.data(); - - // shape: [num_blocks, num_heads, block_size, adjusted_head_size] - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + - head_idx * block_size * adjusted_head_size; - size_t output_offset = output_block_offset + - token_idx * v_head_size; - set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); - - size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; - set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); - set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); - } else { - // shape: [num_blocks, num_heads, block_size, v_head_size] - size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + - head_idx * block_size * v_head_size + - token_idx * v_head_size; - - set_values(test_stream, memory, data_ptr, v_head_size, output_offset); - } - } - } - } - } - } - - return memory; - } - -#else - memory::ptr get_key_cache_memory() { - auto key_cache_dt = data_types::f16; - auto adjusted_head_size = k_head_size; - auto adjusted_block_size = block_size; - if (kv_cache_compression) { - key_cache_dt = data_types::i8; - if (key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { - adjusted_block_size += 4; - } else { - adjusted_head_size += 4; - } - } - - auto num_blocks = block_indices.back() + 1; - auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, adjusted_head_size, adjusted_block_size }; - auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; - auto memory = test_engine.allocate_memory(key_cache_layout); - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - int past_len = subsequence_descs[i].past_len; - if (past_len != 0) { - int blocks_num = ceil_div(past_len + 1, block_size); - int start_block_idx = block_indices[block_indices_begins[i]]; - for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size - : block_size; - // quantize by channel - if (kv_cache_compression && key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { - std::vector token_block(block_size); - for (int token_idx = 0; token_idx < last_token_idx; ++token_idx) { - size_t input_token_offset = block_idx * block_size + token_idx; - token_block[token_idx] = *(key_data[i].data() + input_token_offset * num_heads * k_head_size + head_idx * k_head_size + k_head_size_idx); - } - auto [quantized_data, scale, zp] = quantize_data(token_block.data(), last_token_idx, true); - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * adjusted_head_size * adjusted_block_size + - head_idx * adjusted_head_size * adjusted_block_size; - size_t output_offset = output_block_offset + - k_head_size_idx * adjusted_block_size; - set_values(test_stream, memory, quantized_data.data(), last_token_idx, output_offset); - size_t comp_offset = (output_offset + block_size)/2; - set_values(test_stream, memory, &scale, 1, comp_offset); - set_values(test_stream, memory, &zp, 1, comp_offset + 1); - } - } - } - for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - if (kv_cache_compression) { - if (key_cache_quant_mode == ov::internal::CacheQuantMode::BY_TOKEN) { - // quantize by token - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + - input_token_offset * num_heads * k_head_size + - head_idx * k_head_size; - // shape: [num_blocks, num_heads, adjusted_head_size, block_size] - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * adjusted_head_size * block_size + - head_idx * adjusted_head_size * block_size; - - auto [quantized_data, scale, zp] = quantize_data(data_ptr, k_head_size); - for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { - auto quantized_data_ptr = quantized_data.data() + k_head_size_idx; - - size_t output_offset = output_block_offset + - k_head_size_idx * block_size + - token_idx; - - set_values(test_stream, memory, quantized_data_ptr, 1, output_offset); - } - size_t comp_offset = (output_block_offset + k_head_size * block_size) / 2; - set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); - set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); - } - } else { - for (int k_head_size_idx = 0; k_head_size_idx < k_head_size; k_head_size_idx++) { - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + - input_token_offset * num_heads * k_head_size + - head_idx * k_head_size + k_head_size_idx; - - // shape: [num_blocks, num_heads, k_head_size, block_size] - size_t output_offset = (start_block_idx + block_idx) * num_heads * k_head_size * block_size + - head_idx * k_head_size * block_size + - k_head_size_idx * block_size + - token_idx; - - set_values(test_stream, memory, data_ptr, 1, output_offset); - } - } - } - } - } - } - } - - return memory; - } -#endif - - memory::ptr get_value_cache_memory() { - auto value_cache_dt = data_types::f16; - auto adjusted_head_size = v_head_size; - if (kv_cache_compression) { - value_cache_dt = data_types::i8; - adjusted_head_size += 4; - } - - auto num_blocks = block_indices.back() + 1; - auto value_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; - auto value_cache_layout = layout{ value_cache_shape, value_cache_dt, format::bfyx }; - auto memory = test_engine.allocate_memory(value_cache_layout); - - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - int past_len = subsequence_descs[i].past_len; - if (past_len != 0) { - int blocks_num = ceil_div(past_len + 1, block_size); - int start_block_idx = block_indices[block_indices_begins[i]]; - for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size - : block_size; - for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = value_data[i].data() + - input_token_offset * num_heads * v_head_size + - head_idx * v_head_size; - if (kv_cache_compression) { - auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); - auto quantized_data_ptr = quantized_data.data(); - - // shape: [num_blocks, num_heads, block_size, adjusted_head_size] - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + - head_idx * block_size * adjusted_head_size; - size_t output_offset = output_block_offset + - token_idx * v_head_size; - set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); - - size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; - set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); - set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); - } else { - // shape: [num_blocks, num_heads, block_size, v_head_size] - size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + - head_idx * block_size * v_head_size + - token_idx * v_head_size; - - set_values(test_stream, memory, data_ptr, v_head_size, output_offset); - } - } - } - } - } - } - - return memory; - } - - memory::ptr get_past_lens_memory() { - return get_memory_from_vec(past_lens); - } - - memory::ptr get_subsequence_begins_memory() { - return get_memory_from_vec(subsequence_begins); - } - - memory::ptr get_block_indices_memory() { - return get_memory_from_vec(block_indices); - } - - memory::ptr get_block_indices_begins_memory() { - return get_memory_from_vec(block_indices_begins); - } - - memory::ptr get_scale_memory() { - std::vector scale = { ov::float16(get_default_scale()) }; - return get_memory_from_vec(scale); - } - - memory::ptr get_sliding_window_memory() { - std::vector sliding_window = { 0 }; - return get_memory_from_vec(sliding_window); - } - - memory::ptr get_alibi_memory() { - std::vector alibi; - return get_memory_from_vec(alibi); - } - - memory::ptr get_max_context_len_memory() { - return get_memory_from_vec(max_context_len); - } - - memory::ptr get_score_aggregation() { - return get_memory_from_vec(score_aggregation); - } - - memory::ptr get_rotated_block_indices_memory() { - return get_memory_from_vec(rotated_block_indices); - } - - memory::ptr get_rotation_deltas_memory() { - auto mem = get_memory_from_vec(rotation_deltas); - auto layout = mem->get_layout(); - auto last_dim = rotation_config.per_block ? 1 : block_size; - layout.set_partial_shape(ov::PartialShape{ static_cast(rotated_block_indices.size()), last_dim }); - - return test_engine.reinterpret_buffer(*mem, layout); - } - - memory::ptr get_rotation_trig_lut_memory() { - auto mem = get_memory_from_vec(rotation_trig_lut); - auto layout = mem->get_layout(); - layout.set_partial_shape(ov::PartialShape{ max_context_len[0], k_head_size }); - - if (rotated_block_indices.empty()) { - auto empty_layout = mem->get_layout(); - empty_layout.set_partial_shape(ov::PartialShape{ 0, k_head_size }); - return test_engine.reinterpret_buffer(*mem, empty_layout); - } - - return test_engine.reinterpret_buffer(*mem, layout); - } - - memory::ptr get_xattention_threshold_memory() { - auto mem = get_memory_from_vec(xattention_threshold); - auto layout = mem->get_layout(); - layout.set_partial_shape(ov::PartialShape{ 1 }); - - if (xattention_threshold.empty()) { - auto empty_layout = mem->get_layout(); - empty_layout.set_partial_shape(ov::PartialShape{ 0 }); - return test_engine.reinterpret_buffer(*mem, empty_layout); - } - - return test_engine.reinterpret_buffer(*mem, layout); - } - - memory::ptr get_xattention_block_size_memory() { - return get_memory_from_vec(xattention_block_size); - } - - memory::ptr get_xattention_stride_memory() { - return get_memory_from_vec(xattention_stride); - } - - float get_default_scale() { - return static_cast(1.f / std::sqrt(k_head_size)); - } - -private: - template - memory::ptr get_memory_from_vec(std::vector& input_data) { - auto data_size = input_data.empty() ? 1 : input_data.size(); - auto shape = ov::PartialShape{ static_cast(data_size) }; - auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; - auto memory = test_engine.allocate_memory(layout); - - if (input_data.empty()) { - auto shape = ov::PartialShape{0}; - auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; - return test_engine.reinterpret_buffer(*memory, layout); - } - - set_values(test_stream, memory, input_data.data(), input_data.size(), 0); - - return memory; - } - - memory::ptr get_QKV_memory(std::vector>& input_data, int k_head_size, bool skip_past_len) { - int total_tokens = 0; - for (const auto& subsequence_desc : subsequence_descs) - total_tokens += subsequence_desc.num_tokens; - - auto query_shape = ov::PartialShape{ total_tokens, num_heads * k_head_size }; - auto query_layout = layout{ query_shape, data_types::f16, format::bfyx }; - auto memory = test_engine.allocate_memory(query_layout); - - for (int subsequence_idx = 0; subsequence_idx < static_cast(subsequence_descs.size()); subsequence_idx++) { - for (int token_idx = 0; token_idx < subsequence_descs[subsequence_idx].num_tokens; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - size_t input_token_offset = token_idx; - // as generated data stored in vectors includes past_len, ignore it for KV inputs - if (skip_past_len) - input_token_offset += subsequence_descs[subsequence_idx].past_len; - - ov::float16* data_ptr = input_data[subsequence_idx].data() + - input_token_offset * num_heads * k_head_size + - head_idx * k_head_size; - - size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx; - size_t output_offset = output_token_offset * num_heads * k_head_size + - head_idx * k_head_size; - - set_values(test_stream, memory, data_ptr, k_head_size, output_offset); - } - } - } - - return memory; - } - - template - static void set_values(stream& stream, memory::ptr mem, T* vals, size_t size, size_t dst_offset) { - mem_lock mem_ptr(mem, stream); - for (size_t i = 0; i < size; i++) { - mem_ptr[dst_offset + i] = vals[i]; - } - } - - static std::vector generate_input_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t k_head_size) { - const size_t total_elements_num = tokens_num * num_heads * k_head_size; - auto data = rg.generate_random_1d(total_elements_num, -1, 1); - - // test code - // auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 10000); - - return data; - } - -static std::vector generate_input_data_ww( - tests::random_generator& rg, - size_t num_heads, - size_t tokens_num, - size_t k_head_size, - float stddev = 0.5f, // 控制数据分布集中程度 - bool normalize = true // 是否对每个向量做归一化 -) { - const size_t total_elements_num = tokens_num * num_heads * k_head_size; - auto data = rg.generate_random_1d(total_elements_num, -1, 1); - - // 将均匀分布映射到近似正态分布 - for (size_t i = 0; i < total_elements_num; ++i) { - float x = static_cast(data[i]); - // Box-Muller transform for simple Gaussian-like distribution - float u1 = (x + 1.f) / 2.f; // [0,1] - float u2 = rg.generate_random_1d(1, 0.f, 1.f)[0]; // 另一个随机数 - float r = std::sqrt(-2.f * std::log(u1 + 1e-6f)) * stddev; // 避免 log(0) - float theta = 2.f * 3.1415926535f * u2; - float val = r * std::cos(theta); - data[i] = ov::float16(val); - } - - if (normalize) { - // 对每个 head 的每个 token 做 L2 归一化 - for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { - for (size_t token_idx = 0; token_idx < tokens_num; ++token_idx) { - float norm = 0.f; - for (size_t dim = 0; dim < k_head_size; ++dim) { - float val = static_cast(data[head_idx * tokens_num * k_head_size + token_idx * k_head_size + dim]); - norm += val * val; - } - norm = std::sqrt(norm) + 1e-6f; - for (size_t dim = 0; dim < k_head_size; ++dim) { - size_t idx = head_idx * tokens_num * k_head_size + token_idx * k_head_size + dim; - data[idx] = ov::float16(static_cast(data[idx]) / norm); - } - } - } - } - - return data; -} - - static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { - const size_t total_elements_num = per_block ? rotated_blocks_num - : rotated_blocks_num * block_size; - auto data = rg.generate_random_1d(total_elements_num, 0, static_cast(max_tokens_num - 1)); - - return data; - } - - static std::vector generate_rotation_trig_lut_data(tests::random_generator& rg, size_t max_tokens_num, size_t k_head_size) { - const size_t total_elements_num = max_tokens_num * k_head_size; - auto data = rg.generate_random_1d(total_elements_num, -1, 1); - - return data; - } - - static std::tuple, ov::float16, ov::float16> quantize_data(ov::float16* data, size_t size, bool expand_range = false) { - float min_value = std::numeric_limits::max(); - float max_value = std::numeric_limits::lowest(); - - for (size_t i = 0; i < size; i++) { - min_value = std::min((float)(data[i]), min_value); - max_value = std::max((float)(data[i]), max_value); - } - - float diff_value = 0.001; - if (max_value != min_value) - diff_value = max_value - min_value; - if (expand_range && std::abs(diff_value) <= std::abs(max_value) * 0.1f) { - // compensate too small range - diff_value = (max_value - min_value) + std::max(1.0f, max_value * 0.1f); - } - float scale = (std::numeric_limits::max() - std::numeric_limits::lowest()) / diff_value; - float zp = ((float)-min_value * scale) + std::numeric_limits::lowest(); - - std::vector quantized_data; - quantized_data.resize(size); - - auto convert_char_rte = [](float val) { - float rounded = std::nearbyint(val); - - if (rounded > 127.0f) { - return static_cast(127); - } else if (rounded < -128.0f) { - return static_cast(-128); - } else { - return static_cast(rounded); - } - }; - - for (size_t i = 0; i < size; i++) { - quantized_data[i] = convert_char_rte(data[i] * scale + zp); - } - - scale = 1.0f / scale; - - return std::make_tuple(quantized_data, scale, zp); - } -}; From 6f7dd8d7be00dcc06a8ec2abc714767a04302d65 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Wed, 15 Oct 2025 21:57:38 +0800 Subject: [PATCH 67/96] Clean code --- .../tests/unit/test_cases/paged_attention_gpu_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 687da05dc8bb77..0657488e706f4e 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1373,4 +1373,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing: paged_attention_test_params{ {{5, 10}}, 2, 64, 64, 16, 2, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 34}, {2, 20}, {10, 34}}, 2, 64, 64, 16, 10, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token #endif -})); \ No newline at end of file +})); From 326ee44b41f9407390db4c7a6e6b90552d5f9efd Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Wed, 15 Oct 2025 22:01:21 +0800 Subject: [PATCH 68/96] Add more test cases --- .../intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index a839bec479c9c7..66c6c2cec6c246 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -1590,4 +1590,10 @@ INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, xattention_test_params{ {{1, 31}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token xattention_test_params{ {{1, 32}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 1023}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 1024}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 127}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 128}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 129}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 32}}, 28, 128, 128, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token })); From cfa1f3ac85bb04331db1f28c77aca4e42a81df05 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Wed, 15 Oct 2025 22:05:17 +0800 Subject: [PATCH 69/96] Clean code --- .../unit/test_cases/xattention_gpu_test.cpp | 44 ------------------- 1 file changed, 44 deletions(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index 66c6c2cec6c246..ed78e22ac50394 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -501,50 +501,6 @@ struct PagedAttentionManager { return data; } -static std::vector generate_input_data_ww( - tests::random_generator& rg, - size_t num_heads, - size_t tokens_num, - size_t k_head_size, - float stddev = 0.5f, // 控制数据分布集中程度 - bool normalize = true // 是否对每个向量做归一化 -) { - const size_t total_elements_num = tokens_num * num_heads * k_head_size; - auto data = rg.generate_random_1d(total_elements_num, -1, 1); - - // 将均匀分布映射到近似正态分布 - for (size_t i = 0; i < total_elements_num; ++i) { - float x = static_cast(data[i]); - // Box-Muller transform for simple Gaussian-like distribution - float u1 = (x + 1.f) / 2.f; // [0,1] - float u2 = rg.generate_random_1d(1, 0.f, 1.f)[0]; // 另一个随机数 - float r = std::sqrt(-2.f * std::log(u1 + 1e-6f)) * stddev; // 避免 log(0) - float theta = 2.f * 3.1415926535f * u2; - float val = r * std::cos(theta); - data[i] = ov::float16(val); - } - - if (normalize) { - // 对每个 head 的每个 token 做 L2 归一化 - for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { - for (size_t token_idx = 0; token_idx < tokens_num; ++token_idx) { - float norm = 0.f; - for (size_t dim = 0; dim < k_head_size; ++dim) { - float val = static_cast(data[head_idx * tokens_num * k_head_size + token_idx * k_head_size + dim]); - norm += val * val; - } - norm = std::sqrt(norm) + 1e-6f; - for (size_t dim = 0; dim < k_head_size; ++dim) { - size_t idx = head_idx * tokens_num * k_head_size + token_idx * k_head_size + dim; - data[idx] = ov::float16(static_cast(data[idx]) / norm); - } - } - } - } - - return data; -} - static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { const size_t total_elements_num = per_block ? rotated_blocks_num : rotated_blocks_num * block_size; From f402a14e6a1f7393f039d06523778bb4075ffa7c Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 15 Oct 2025 14:41:51 +0800 Subject: [PATCH 70/96] refactor: check single suquence condition --- .../src/graph/impls/cm/paged_attention.hpp | 25 ++++++++++++------- .../graph/impls/cm/paged_attention_gen.cpp | 1 - .../intel_gpu/src/graph/paged_attention.cpp | 12 ++++++++- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp index 1f3ff893011884..7d5c5693691a3e 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp @@ -49,30 +49,37 @@ struct PagedAttentionImplementationManager : public ImplementationManager { const auto& info = engine.get_device_info(); // CM optimized for systolic-array architectures if (!check_cm_jit_support(engine, config) || !info.supports_immad || !config.get_use_cm()) { - GPU_DEBUG_TRACE_DETAIL << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported GPU architecture. " << std::endl; return false; } - const auto& q_layout = node.get_input_layout(0); - const auto& k_layout = node.get_input_layout(1); - const auto& v_layout = node.get_input_layout(2); + const auto& q_layout = node.get_input_layout(PagedAttentionInputIdx::QUERY); + const auto& k_layout = node.get_input_layout(PagedAttentionInputIdx::KEY); + const auto& v_layout = node.get_input_layout(PagedAttentionInputIdx::VALUE); const auto& out_layout = node.get_output_layout(0); if (!everyone_is(format::bfyx, q_layout.format, k_layout.format, v_layout.format, out_layout.format)) { - GPU_DEBUG_TRACE_DETAIL << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported qkv layout. " << std::endl; return false; } - if (!one_of(k_layout.data_type, supported_kv_types) || !one_of(v_layout.data_type, supported_kv_types)) { - GPU_DEBUG_TRACE_DETAIL << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + if (!one_of(k_layout.data_type, supported_q_types) || !one_of(v_layout.data_type, supported_q_types)) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported kv data type. " << std::endl; return false; } if (!one_of(q_layout.data_type, supported_q_types) || !one_of(out_layout.data_type, supported_q_types)) { - GPU_DEBUG_TRACE_DETAIL << __LINE__ << ": ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - false " << std::endl; + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported q/out data type. " << std::endl; return false; } - GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionImplementationManager::validate_impl() - true" << std::endl; + const auto& kcache_layout = node.get_input_layout(PagedAttentionInputIdx::KEY_CACHE); + const auto& vcache_layout = node.get_input_layout(PagedAttentionInputIdx::VALUE_CACHE); + if (!one_of(kcache_layout.data_type, supported_kv_types) || !one_of(vcache_layout.data_type, supported_kv_types)) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported kv cache data type. " << std::endl; + return false; + } + + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - true" << std::endl; return true; } }; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 84c58e595dcff6..954b6db40e8b38 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -443,7 +443,6 @@ Arguments PagedAttentionGeneratorSingleToken::get_arguments_desc(const kernel_im // scalar args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len==1 - // args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // kv_partition_num return args; } diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index ca7a3907d1ebe4..89c11ee000cfda 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -47,9 +47,19 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no const auto block_size_idx = desc->has_xattention ? 2 : 3; bool valid_block_size = key_cache_ps.is_dynamic() || - (key_cache_ps[block_size_idx].get_length() == static_cast(expected_block_size)); + (key_cache_ps[block_size_idx].get_length() == static_cast(expected_block_size)); OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation for key cache quant mode " , key_cache_quant_mode, ". Expected ", expected_block_size, ", but got ", key_cache_ps[block_size_idx].get_length()); + + // TODO: as a preview feature, only single sequence is supported so far. Will remove this check once + // full function ready in near future. + if (desc->has_xattention) { + const auto& subseq_begins_ps = impl_param.get_input_layout(PagedAttentionInputIdx::SUBSEQUENCE_BEGINS).get_partial_shape(); + bool valid_subseq_count = subseq_begins_ps.is_dynamic() || + (subseq_begins_ps[0].get_length() == static_cast(2)); + OPENVINO_ASSERT(valid_subseq_count, "[GPU] Unexpected sub sequences count for XAttention. Got ", subseq_begins_ps[0].get_length() - 1); + } + std::vector output_layouts{ data_layout }; if (desc->has_scores_output()) { From 342ae59815564cae5edfb371611db3f34a429e00 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Wed, 15 Oct 2025 15:33:44 +0800 Subject: [PATCH 71/96] Avoid 2nd token perf drop due to cleanup unused K cache --- .../src/graph/impls/cm/pa_single_token.cm | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index 2cdf221c1e6042..058e4175d7091e 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -219,13 +219,14 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } #else cm_load(Kt.format(), b2dK.set_block_y(kv_pos)); - if(kv_pos_end < kv_pos + KV_STEP) { - auto KmatRef = Kt.format(); - uint valid_cols = kv_pos_end - kv_pos; - uint valid_cols_vnni = valid_cols * 2; - for (int r = valid_cols_vnni; r < KV_STEP * 2; r++) - KmatRef.select(0,r) = 0.0f; - } + // Not need clean K cache: 1) col write will lead to huge perf drop; 2) softmax will clear unused scores + // if(kv_pos_end < kv_pos + KV_STEP) { + // auto KmatRef = Kt.format(); + // uint valid_cols = kv_pos_end - kv_pos; + // uint valid_cols_vnni = valid_cols * 2; + // for (int r = valid_cols_vnni; r < KV_STEP * 2; r++) + // KmatRef.select(0,r) = 0.0f; + // } #endif #else matrix temp; @@ -405,7 +406,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( prepack_to_VNNI_W2(VmatNormal, Vmat.format()); #else cm_load(Vmat[0].format(), b2dV.set_block_y(kv_pos)); - // somtimes KV cache would be filled with random Nan, so need to clean up the unused value data. + // Sometimes KV cache would be filled with random NAN(found in PTL), so need to clean up the unused value data. if(kv_pos_end - kv_pos < KV_STEP) { auto VmatRef = Vmat[0].format(); uint valid_rows = kv_pos_end - kv_pos; From 8e8b74cc3a05288b03eb32cf10325d5eb6498c06 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 15 Oct 2025 17:11:54 +0800 Subject: [PATCH 72/96] fix: if kvcache config is dynamic, which may occurs with a typo error from user, throw exception. --- .../src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp | 7 +++++++ .../intel_gpu/src/graph/registry/paged_attention_impls.cpp | 4 +--- .../intel_gpu/src/plugin/transformations_pipeline.cpp | 6 ++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp index 8fccefbc9e5eae..f0e80007d7f60c 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp @@ -10,6 +10,7 @@ #include "program_node.h" #include "registry/implementation_manager.hpp" +#include "paged_attention_inst.h" using namespace cldnn; // TODO: Remove once namespaces are aligned @@ -30,6 +31,12 @@ struct PagedAttentionOpt : public ImplementationManager { ov::element::i8, }; + auto desc = node.as().get_primitive(); + if (desc->has_xattention) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false because XAttention is not supported with ocl. " << std::endl; + return false; + } + const auto& q_layout = node.get_input_layout(0); const auto& k_layout = node.get_input_layout(1); const auto& v_layout = node.get_input_layout(2); diff --git a/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp b/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp index 8283553ad73700..6c68f8806a410e 100644 --- a/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp +++ b/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp @@ -21,10 +21,8 @@ using namespace cldnn; const std::vector>& Registry::get_implementations() { static const std::vector> impls = { -#if OV_GPU_WITH_CM - OV_GPU_CREATE_INSTANCE_CM(cm::PagedAttentionImplementationManager, shape_types::any) -#endif OV_GPU_CREATE_INSTANCE_OCL(ocl::PagedAttentionOpt, shape_types::any) + OV_GPU_CREATE_INSTANCE_CM(cm::PagedAttentionImplementationManager, shape_types::any) }; return impls; diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index cd7d5295b76431..a55c57df48204c 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -558,8 +558,10 @@ void TransformationsPipeline::apply(std::shared_ptr func) { // k: [num_blocks, num_kv_heads, block_size(256), head_size] // v: [num_blocks, num_kv_heads, block_size(256), head_size] ov::pass::ConvertPagedAttnInputs::KVCacheConfig kv_cache_config; - kv_cache_config.keyCachePrecision = config.get_kv_cache_precision(); - kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); + const auto kv_cache_precision = config.get_kv_cache_precision(); + OPENVINO_ASSERT(kv_cache_precision != ov::element::dynamic, "[GPU] kv_cache precision should be specified."); + kv_cache_config.keyCachePrecision = kv_cache_precision; + kv_cache_config.valueCachePrecision = kv_cache_precision; kv_cache_config.inferencePrecision = infer_precision; if (use_xattention) { kv_cache_config.keyCacheBlockSize = cldnn::paged_attention::block_size_xattn; From 35267d39bcb0391f286c538f572ae41aa385f244 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Thu, 16 Oct 2025 09:29:18 +0800 Subject: [PATCH 73/96] Fix build errors and code style (#59) * Add Sinks to PA * Set xattention threshold * Format code style --- .../unit/test_cases/xattention_gpu_test.cpp | 396 ++++++++---------- 1 file changed, 173 insertions(+), 223 deletions(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index ed78e22ac50394..87e107b3c70403 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -11,12 +11,12 @@ #include #include #include + #include "openvino/reference/divide.hpp" #include "openvino/reference/matmul.hpp" #include "openvino/reference/softmax.hpp" #include "openvino/reference/transpose.hpp" #include "openvino/runtime/tensor.hpp" - #include "random_generator.hpp" #include "test_utils.h" @@ -24,53 +24,14 @@ using namespace cldnn; using namespace ov::intel_gpu; using namespace ::tests; -/* -* PagedAttention inputs: -* [0]: query -* shape: [batch_size_in_tokens, num_heads * head_size], type: f16 -* [1]: key -* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 -* [2]: value  -* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 -* [3]: key_cache -* shape: [num_blocks, num_kv_heads, head_size, block_size], type: f16 -* [4]: value_cache -* shape: [num_blocks, num_kv_heads, block_size, head_size], type: f16 -* [5]: past_lens -* shape: [batch_size_in_sequences], type: i32 -* [6]: subsequence_begins -* shape: [batch_size_in_sequences + 1], type: i32 -* [7]: block_indices -* Shape: [num_blocks], type: i32 -* [8]: block_indices_begins -* Shape: [batch_size_in_sequences + 1], type: i32 -* [9]: scale, optional -* [10]: sliding_window, optional -* [11]: alibi_slopes, optional -* [12]: max_context_len -* shape: [], type: i32 -* [13]: score_aggregation_window​, optional​, shape: [batch_size_in_sequences] -* [14]: rotated_block_indices​, optional​ -* shape: [num_rotated_blocks]​, type: i32 -* [15]: rotation_deltas​, optional​ -* shape: [num_rotated_blocks, BLOCK_SIZE]​ || [num_rotated_blocks, 1]​, type: i32 -* [16]: rotation_trig_lut​, optional​ -* shape: [max_num_batched_tokens / BLOCK_SIZE, head_size]​ || [max_num_batched_tokens, head_size], type: f16 -*/ - - -enum class ScoresMode { - DISABLED = 0, - LAST_TOKEN, - SNAPKV -}; +enum class XAttentionScoresMode { DISABLED = 0, LAST_TOKEN, SNAPKV }; -struct SubsequenceDescriptor { +struct XAttentionSubsequenceDescriptor { int num_tokens; int past_len; }; -struct CacheRotationDescriptor { +struct XAttentionCacheRotationDescriptor { bool apply_rotation; // configures 2nd dimension of rotation_deltas // if per_block is true, single value is used for all tokens inside the block @@ -78,7 +39,7 @@ struct CacheRotationDescriptor { bool per_block; }; -struct PagedAttentionManager { +struct XAttentionManager { int num_heads; int k_head_size; int v_head_size; @@ -87,13 +48,13 @@ struct PagedAttentionManager { bool kv_cache_compression; ov::internal::CacheQuantMode key_cache_quant_mode; bool has_score_aggregation; - CacheRotationDescriptor rotation_config; - std::vector subsequence_descs; + XAttentionCacheRotationDescriptor rotation_config; + std::vector subsequence_descs; // per-subsequence QKV inputs - std::vector> query_data; // {[1, num_tokens, num_heads, k_head_size], ..} - std::vector> key_data; // {[1, past_len + num_tokens, num_heads, k_head_size], ..} - std::vector> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..} + std::vector> query_data; // {[1, num_tokens, num_heads, k_head_size], ..} + std::vector> key_data; // {[1, past_len + num_tokens, num_heads, k_head_size], ..} + std::vector> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..} // common PA inputs std::vector past_lens; @@ -111,40 +72,42 @@ struct PagedAttentionManager { std::vector rotation_deltas; std::vector rotation_trig_lut; - std::vector xattention_threshold; + std::vector xattention_threshold = {0.9}; std::vector xattention_block_size; std::vector xattention_stride; + std::vector sinks; + cldnn::engine& test_engine; cldnn::stream& test_stream; tests::random_generator& rg; - PagedAttentionManager(tests::random_generator& rg, - cldnn::engine& engine, - cldnn::stream& stream, - const std::vector& subsequence_descs, - int num_heads, - int k_head_size, - int v_head_size, - int block_size, - int sliding_window_size, - bool kv_cache_compression, - ov::internal::CacheQuantMode key_cache_quant_mode, - bool has_score_aggregation, - CacheRotationDescriptor rotation_config) - : num_heads(num_heads) - , k_head_size(k_head_size) - , v_head_size(v_head_size) - , block_size(block_size) - , sliding_window_size(sliding_window_size) - , kv_cache_compression(kv_cache_compression) - , key_cache_quant_mode(key_cache_quant_mode) - , has_score_aggregation(has_score_aggregation) - , rotation_config(rotation_config) - , subsequence_descs(subsequence_descs) - , test_engine(engine) - , test_stream(stream) - , rg(rg) { + XAttentionManager(tests::random_generator& rg, + cldnn::engine& engine, + cldnn::stream& stream, + const std::vector& subsequence_descs, + int num_heads, + int k_head_size, + int v_head_size, + int block_size, + int sliding_window_size, + bool kv_cache_compression, + ov::internal::CacheQuantMode key_cache_quant_mode, + bool has_score_aggregation, + XAttentionCacheRotationDescriptor rotation_config) + : num_heads(num_heads), + k_head_size(k_head_size), + v_head_size(v_head_size), + block_size(block_size), + sliding_window_size(sliding_window_size), + kv_cache_compression(kv_cache_compression), + key_cache_quant_mode(key_cache_quant_mode), + has_score_aggregation(has_score_aggregation), + rotation_config(rotation_config), + subsequence_descs(subsequence_descs), + test_engine(engine), + test_stream(stream), + rg(rg) { // init subsequence_begins and block_indices_begins subsequence_begins.push_back(0); block_indices_begins.push_back(0); @@ -192,11 +155,7 @@ struct PagedAttentionManager { } if (!rotated_block_indices.empty()) { - rotation_deltas = generate_rotation_deltas_data(rg, - max_context_len[0], - rotated_block_indices.size(), - block_size, - rotation_config.per_block); + rotation_deltas = generate_rotation_deltas_data(rg, max_context_len[0], rotated_block_indices.size(), block_size, rotation_config.per_block); rotation_trig_lut = generate_rotation_trig_lut_data(rg, max_context_len[0], k_head_size); } } @@ -232,8 +191,8 @@ struct PagedAttentionManager { } auto num_blocks = block_indices.back() + 1; - auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; - auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; + auto key_cache_shape = ov::PartialShape{num_blocks, num_heads, block_size, adjusted_head_size}; + auto key_cache_layout = layout{key_cache_shape, key_cache_dt, format::bfyx}; auto memory = test_engine.allocate_memory(key_cache_layout); for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { @@ -242,23 +201,19 @@ struct PagedAttentionManager { int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size - : block_size; + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size : block_size; for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { for (int head_idx = 0; head_idx < num_heads; head_idx++) { size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + - input_token_offset * num_heads * v_head_size + - head_idx * v_head_size; + ov::float16* data_ptr = key_data[i].data() + input_token_offset * num_heads * v_head_size + head_idx * v_head_size; if (kv_cache_compression) { auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); auto quantized_data_ptr = quantized_data.data(); // shape: [num_blocks, num_heads, block_size, adjusted_head_size] - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + - head_idx * block_size * adjusted_head_size; - size_t output_offset = output_block_offset + - token_idx * v_head_size; + size_t output_block_offset = + (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + token_idx * v_head_size; set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; @@ -267,8 +222,7 @@ struct PagedAttentionManager { } else { // shape: [num_blocks, num_heads, block_size, v_head_size] size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + - head_idx * block_size * v_head_size + - token_idx * v_head_size; + head_idx * block_size * v_head_size + token_idx * v_head_size; set_values(test_stream, memory, data_ptr, v_head_size, output_offset); } @@ -290,8 +244,8 @@ struct PagedAttentionManager { } auto num_blocks = block_indices.back() + 1; - auto value_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; - auto value_cache_layout = layout{ value_cache_shape, value_cache_dt, format::bfyx }; + auto value_cache_shape = ov::PartialShape{num_blocks, num_heads, block_size, adjusted_head_size}; + auto value_cache_layout = layout{value_cache_shape, value_cache_dt, format::bfyx}; auto memory = test_engine.allocate_memory(value_cache_layout); for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { @@ -300,23 +254,19 @@ struct PagedAttentionManager { int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size - : block_size; + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size : block_size; for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { for (int head_idx = 0; head_idx < num_heads; head_idx++) { size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = value_data[i].data() + - input_token_offset * num_heads * v_head_size + - head_idx * v_head_size; + ov::float16* data_ptr = value_data[i].data() + input_token_offset * num_heads * v_head_size + head_idx * v_head_size; if (kv_cache_compression) { auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); auto quantized_data_ptr = quantized_data.data(); // shape: [num_blocks, num_heads, block_size, adjusted_head_size] - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + - head_idx * block_size * adjusted_head_size; - size_t output_offset = output_block_offset + - token_idx * v_head_size; + size_t output_block_offset = + (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + token_idx * v_head_size; set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; @@ -325,8 +275,7 @@ struct PagedAttentionManager { } else { // shape: [num_blocks, num_heads, block_size, v_head_size] size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + - head_idx * block_size * v_head_size + - token_idx * v_head_size; + head_idx * block_size * v_head_size + token_idx * v_head_size; set_values(test_stream, memory, data_ptr, v_head_size, output_offset); } @@ -356,12 +305,12 @@ struct PagedAttentionManager { } memory::ptr get_scale_memory() { - std::vector scale = { ov::float16(get_default_scale()) }; + std::vector scale = {ov::float16(get_default_scale())}; return get_memory_from_vec(scale); } memory::ptr get_sliding_window_memory() { - std::vector sliding_window = { 0 }; + std::vector sliding_window = {0}; return get_memory_from_vec(sliding_window); } @@ -386,7 +335,7 @@ struct PagedAttentionManager { auto mem = get_memory_from_vec(rotation_deltas); auto layout = mem->get_layout(); auto last_dim = rotation_config.per_block ? 1 : block_size; - layout.set_partial_shape(ov::PartialShape{ static_cast(rotated_block_indices.size()), last_dim }); + layout.set_partial_shape(ov::PartialShape{static_cast(rotated_block_indices.size()), last_dim}); return test_engine.reinterpret_buffer(*mem, layout); } @@ -394,11 +343,11 @@ struct PagedAttentionManager { memory::ptr get_rotation_trig_lut_memory() { auto mem = get_memory_from_vec(rotation_trig_lut); auto layout = mem->get_layout(); - layout.set_partial_shape(ov::PartialShape{ max_context_len[0], k_head_size }); + layout.set_partial_shape(ov::PartialShape{max_context_len[0], k_head_size}); if (rotated_block_indices.empty()) { auto empty_layout = mem->get_layout(); - empty_layout.set_partial_shape(ov::PartialShape{ 0, k_head_size }); + empty_layout.set_partial_shape(ov::PartialShape{0, k_head_size}); return test_engine.reinterpret_buffer(*mem, empty_layout); } @@ -408,11 +357,10 @@ struct PagedAttentionManager { memory::ptr get_xattention_threshold_memory() { auto mem = get_memory_from_vec(xattention_threshold); auto layout = mem->get_layout(); - layout.set_partial_shape(ov::PartialShape{ 1 }); - + layout.set_partial_shape(ov::PartialShape{1}); if (xattention_threshold.empty()) { auto empty_layout = mem->get_layout(); - empty_layout.set_partial_shape(ov::PartialShape{ 0 }); + empty_layout.set_partial_shape(ov::PartialShape{0}); return test_engine.reinterpret_buffer(*mem, empty_layout); } @@ -427,21 +375,35 @@ struct PagedAttentionManager { return get_memory_from_vec(xattention_stride); } + memory::ptr get_sinks_memory() { + auto mem = get_memory_from_vec(sinks); + auto layout = mem->get_layout(); + layout.set_partial_shape(ov::PartialShape{1, num_heads, 1, 1}); + + if (sinks.empty()) { + auto empty_layout = mem->get_layout(); + empty_layout.set_partial_shape(ov::PartialShape{0, 0, 0, 0}); + return test_engine.reinterpret_buffer(*mem, empty_layout); + } + + return test_engine.reinterpret_buffer(*mem, layout); + } + float get_default_scale() { return static_cast(1.f / std::sqrt(k_head_size)); } private: - template + template memory::ptr get_memory_from_vec(std::vector& input_data) { auto data_size = input_data.empty() ? 1 : input_data.size(); - auto shape = ov::PartialShape{ static_cast(data_size) }; - auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + auto shape = ov::PartialShape{static_cast(data_size)}; + auto layout = cldnn::layout{shape, ov::element::from(), format::bfyx}; auto memory = test_engine.allocate_memory(layout); if (input_data.empty()) { auto shape = ov::PartialShape{0}; - auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + auto layout = cldnn::layout{shape, ov::element::from(), format::bfyx}; return test_engine.reinterpret_buffer(*memory, layout); } @@ -455,8 +417,8 @@ struct PagedAttentionManager { for (const auto& subsequence_desc : subsequence_descs) total_tokens += subsequence_desc.num_tokens; - auto query_shape = ov::PartialShape{ total_tokens, num_heads * k_head_size }; - auto query_layout = layout{ query_shape, data_types::f16, format::bfyx }; + auto query_shape = ov::PartialShape{total_tokens, num_heads * k_head_size}; + auto query_layout = layout{query_shape, data_types::f16, format::bfyx}; auto memory = test_engine.allocate_memory(query_layout); for (int subsequence_idx = 0; subsequence_idx < static_cast(subsequence_descs.size()); subsequence_idx++) { @@ -467,13 +429,10 @@ struct PagedAttentionManager { if (skip_past_len) input_token_offset += subsequence_descs[subsequence_idx].past_len; - ov::float16* data_ptr = input_data[subsequence_idx].data() + - input_token_offset * num_heads * k_head_size + - head_idx * k_head_size; + ov::float16* data_ptr = input_data[subsequence_idx].data() + input_token_offset * num_heads * k_head_size + head_idx * k_head_size; size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx; - size_t output_offset = output_token_offset * num_heads * k_head_size + - head_idx * k_head_size; + size_t output_offset = output_token_offset * num_heads * k_head_size + head_idx * k_head_size; set_values(test_stream, memory, data_ptr, k_head_size, output_offset); } @@ -483,7 +442,7 @@ struct PagedAttentionManager { return memory; } - template + template static void set_values(stream& stream, memory::ptr mem, T* vals, size_t size, size_t dst_offset) { mem_lock mem_ptr(mem, stream); for (size_t i = 0; i < size; i++) { @@ -495,15 +454,15 @@ struct PagedAttentionManager { const size_t total_elements_num = tokens_num * num_heads * k_head_size; auto data = rg.generate_random_1d(total_elements_num, -1, 1); - // test code - // auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 10000); - return data; } - static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { - const size_t total_elements_num = per_block ? rotated_blocks_num - : rotated_blocks_num * block_size; + static std::vector generate_rotation_deltas_data(tests::random_generator& rg, + size_t max_tokens_num, + size_t rotated_blocks_num, + size_t block_size, + bool per_block) { + const size_t total_elements_num = per_block ? rotated_blocks_num : rotated_blocks_num * block_size; auto data = rg.generate_random_1d(total_elements_num, 0, static_cast(max_tokens_num - 1)); return data; @@ -562,8 +521,7 @@ struct PagedAttentionManager { using Shape = std::vector; -using CMXAttentionBlockIndex = - std::pair; // .first is the *query* dimension block index, .second is *key* +using CMXAttentionBlockIndex = std::pair; // .first is the *query* dimension block index, .second is *key* using CMXAttentionRetainedBlockIndices = std::set; using CMXAttentionRetainedBlockIndicesForAllHeads = std::vector; @@ -574,11 +532,7 @@ class CMXAttentionBlockSelector { OPENVINO_ASSERT(m_block_size % m_stride == 0); } - void diagonal_reshape(const T* input_data, - const Shape& input_shape, - T* output_data, - const Shape& output_shape, - bool is_antidiagonal) { + void diagonal_reshape(const T* input_data, const Shape& input_shape, T* output_data, const Shape& output_shape, bool is_antidiagonal) { OPENVINO_ASSERT(input_shape.size() == 3); OPENVINO_ASSERT(output_shape.size() == 3); size_t H = input_shape[0]; @@ -668,10 +622,8 @@ class CMXAttentionBlockSelector { } } - CMXAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(T* blocked_attention_scores_data, - const Shape& blocked_attention_scores_shape) { - OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3, - "Expected shape [num_heads, q_block_num, k_block_num]"); + CMXAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(T* blocked_attention_scores_data, const Shape& blocked_attention_scores_shape) { + OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3, "Expected shape [num_heads, q_block_num, k_block_num]"); size_t num_heads = blocked_attention_scores_shape[0]; size_t q_block_num = blocked_attention_scores_shape[1]; @@ -679,9 +631,7 @@ class CMXAttentionBlockSelector { CMXAttentionRetainedBlockIndicesForAllHeads retval(num_heads); - std::vector>> mask( - num_heads, - std::vector>(q_block_num, std::vector(k_block_num, false))); + std::vector>> mask(num_heads, std::vector>(q_block_num, std::vector(k_block_num, false))); for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { @@ -782,10 +732,10 @@ class CMXAttentionBlockSelector { } CMXAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, - const Shape& query_shape, - const T* key_data, - const Shape& key_shape, - int chunk_size = -1) { + const Shape& query_shape, + const T* key_data, + const Shape& key_shape, + int chunk_size = -1) { OPENVINO_ASSERT(query_shape.size() == 3 && key_shape.size() == 3); OPENVINO_ASSERT(query_shape[0] == key_shape[0] && query_shape[2] == key_shape[2]); OPENVINO_ASSERT(query_shape[1] % m_stride == 0 && key_shape[1] % m_stride == 0); @@ -795,7 +745,8 @@ class CMXAttentionBlockSelector { const size_t q_len = query_shape[1]; const size_t k_len = key_shape[1]; const size_t head_dim = query_shape[2]; - if (chunk_size == -1) chunk_size = q_len; + if (chunk_size == -1) + chunk_size = static_cast(q_len); auto pad_seq = [&](const T* src_data, size_t seq_len) { size_t num_to_pad = ((seq_len + chunk_size - 1) / chunk_size) * chunk_size - seq_len; @@ -807,8 +758,7 @@ class CMXAttentionBlockSelector { size_t dst_off = h * (seq_len + num_to_pad) * head_dim; std::memcpy(buf.get() + dst_off, src_data + src_off, seq_len * head_dim * sizeof(T)); if (num_to_pad) - std::fill(buf.get() + dst_off + seq_len * head_dim, - buf.get() + dst_off + (seq_len + num_to_pad) * head_dim, T(0)); + std::fill(buf.get() + dst_off + seq_len * head_dim, buf.get() + dst_off + (seq_len + num_to_pad) * head_dim, T(0)); } return std::make_pair(std::move(buf), pad_shape); }; @@ -850,8 +800,7 @@ class CMXAttentionBlockSelector { for (size_t h = 0; h < num_heads; ++h) { for (size_t q = 0; q < reshaped_chunk_size; ++q) { - size_t base = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) - + q * (reshaped_chunk_size * k_chunk_num); + size_t base = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + q * (reshaped_chunk_size * k_chunk_num); for (size_t k = k_reshaped_seq_len - k_reshaped_num_to_pad; k < k_reshaped_seq_len; ++k) causal_mask_buf.get()[base + k] = neg_inf; @@ -873,9 +822,7 @@ class CMXAttentionBlockSelector { // ======== block sum + select ======== const size_t blocks_per_axis = m_block_size / m_stride; - Shape block_sum_shape = {num_heads, - reshaped_q_len / blocks_per_axis, - reshaped_k_len / blocks_per_axis}; + Shape block_sum_shape = {num_heads, reshaped_q_len / blocks_per_axis, reshaped_k_len / blocks_per_axis}; auto block_sum_buf = allocate_buf(block_sum_shape); block_sum_attention_scores(attn_score_buf.get(), qk_shape, block_sum_buf.get(), block_sum_shape); attn_score_buf.reset(); @@ -902,7 +849,7 @@ class CMXAttentionBlockSelector { }; struct xAttentionReference { - xAttentionReference(PagedAttentionManager& pam) : pam(pam), test_engine(pam.test_engine), test_stream(pam.test_stream) {} + xAttentionReference(XAttentionManager& pam) : pam(pam), test_engine(pam.test_engine), test_stream(pam.test_stream) {} std::pair, std::vector> get_reference() { std::vector ref_data_output; @@ -1042,7 +989,7 @@ struct xAttentionReference { CMXAttentionBlockSelector selector(threshold, block_size, stride); retained_blocks = selector.select_blocks(query_padded_f32.data(), query_shape_padded, key_padded_f32.data(), key_shape_padded); } - auto mask_mem = get_mask_mem_combined_multi_head(num_queries, num_keys, num_heads, sliding_window_size, retained_blocks, block_size); + auto mask_mem = get_mask_mem_combined_multi_head(num_queries, num_keys, num_heads, sliding_window_size, retained_blocks, static_cast(block_size)); topology topology; topology.add(input_layout("query", query_layout), @@ -1227,7 +1174,7 @@ struct xAttentionReference { } } - PagedAttentionManager& pam; + XAttentionManager& pam; cldnn::engine& test_engine; cldnn::stream& test_stream; }; @@ -1244,19 +1191,19 @@ struct xAttentionTest : public ::testing::TestWithParam { } void execute(T& p) { - PagedAttentionManager pam(rg, - get_test_engine(), - get_test_stream(), - p.subsequences, - p.num_heads, - p.k_head_size, - p.v_head_size, - p.block_size, - p.sliding_window_size, - p.kv_cache_compression, - p.key_cache_quant_mode, - p.scores_mode == ScoresMode::SNAPKV, - p.rotation_config); + XAttentionManager pam(rg, + get_test_engine(), + get_test_stream(), + p.subsequences, + p.num_heads, + p.k_head_size, + p.v_head_size, + p.block_size, + p.sliding_window_size, + p.kv_cache_compression, + p.key_cache_quant_mode, + p.scores_mode == XAttentionScoresMode::SNAPKV, + p.rotation_config); if (p.kv_cache_compression) tolerance = 25e-3; @@ -1289,6 +1236,7 @@ struct xAttentionTest : public ::testing::TestWithParam { auto xattention_threshold_mem = pam.get_xattention_threshold_memory(); auto xattention_block_size_mem = pam.get_xattention_block_size_memory(); auto xattention_stride_mem = pam.get_xattention_stride_memory(); + auto sinks_mem = pam.get_sinks_memory(); auto query_layout = query_mem->get_layout(); auto key_layout = key_mem->get_layout(); @@ -1310,22 +1258,23 @@ struct xAttentionTest : public ::testing::TestWithParam { auto xattention_threshold_layout = xattention_threshold_mem->get_layout(); auto xattention_block_size_layout = xattention_block_size_mem->get_layout(); auto xattention_stride_layout = xattention_stride_mem->get_layout(); + auto sinks_layout = sinks_mem->get_layout(); // make layouts dynamic - query_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.k_head_size }); - key_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.k_head_size }); - value_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.v_head_size }); - key_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.k_head_size }); - value_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.v_head_size }); - past_lens_layout.set_partial_shape(ov::PartialShape{ -1 }); - subsequence_begins_layout.set_partial_shape(ov::PartialShape{ -1 }); - block_indices_layout.set_partial_shape(ov::PartialShape{ -1 }); - block_indices_begins_layout.set_partial_shape(ov::PartialShape{ -1 }); - score_aggregation_window_layout.set_partial_shape(ov::PartialShape{ -1 }); - rotated_block_indices_layout.set_partial_shape(ov::PartialShape{ -1 }); - rotation_deltas_layout.set_partial_shape(ov::PartialShape{ -1, -1 }); - rotation_trig_lut_layout.set_partial_shape(ov::PartialShape{ -1, p.k_head_size }); - xattention_threshold_layout.set_partial_shape(ov::PartialShape{ -1 }); + query_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads * p.k_head_size}); + key_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads * p.k_head_size}); + value_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads * p.v_head_size}); + key_cache_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads, p.block_size, p.k_head_size}); + value_cache_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads, p.block_size, p.v_head_size}); + past_lens_layout.set_partial_shape(ov::PartialShape{-1}); + subsequence_begins_layout.set_partial_shape(ov::PartialShape{-1}); + block_indices_layout.set_partial_shape(ov::PartialShape{-1}); + block_indices_begins_layout.set_partial_shape(ov::PartialShape{-1}); + score_aggregation_window_layout.set_partial_shape(ov::PartialShape{-1}); + rotated_block_indices_layout.set_partial_shape(ov::PartialShape{-1}); + rotation_deltas_layout.set_partial_shape(ov::PartialShape{-1, -1}); + rotation_trig_lut_layout.set_partial_shape(ov::PartialShape{-1, p.k_head_size}); + xattention_threshold_layout.set_partial_shape(ov::PartialShape{-1}); if (p.dynamic_paddings) { const auto padding_axis = 1; @@ -1347,8 +1296,7 @@ struct xAttentionTest : public ::testing::TestWithParam { auto query_data_shape = query_data_layout.get_shape(); for (size_t b = 0; b < query_data_shape[0]; b++) { for (size_t f = 0; f < query_data_shape[1]; f++) { - auto input_offset = - query_data_layout.get_linear_offset(cldnn::tensor(static_cast(b), static_cast(f), 0, 0, 0, 0)); + auto input_offset = query_data_layout.get_linear_offset(cldnn::tensor(static_cast(b), static_cast(f), 0, 0, 0, 0)); auto output_offset = padded_query_data_layout.get_linear_offset(cldnn::tensor(static_cast(b), static_cast(f), 0, 0, 0, 0)); @@ -1379,6 +1327,7 @@ struct xAttentionTest : public ::testing::TestWithParam { input_info("xattention_threshold"), input_info("xattention_block_size"), input_info("xattention_stride"), + input_info("sinks"), }; auto pa_prim = paged_attention("paged_attention", pa_inputs); @@ -1389,34 +1338,33 @@ struct xAttentionTest : public ::testing::TestWithParam { pa_prim.heads_num = p.num_heads; pa_prim.scale_val = pam.get_default_scale(); pa_prim.has_alibi = false; - pa_prim.num_outputs = p.scores_mode == ScoresMode::DISABLED ? 1 : 2; + pa_prim.num_outputs = p.scores_mode == XAttentionScoresMode::DISABLED ? 1 : 2; pa_prim.has_rotated_blocks = p.rotation_config.apply_rotation; - pa_prim.has_score_aggregation = p.scores_mode == ScoresMode::SNAPKV; + pa_prim.has_score_aggregation = p.scores_mode == XAttentionScoresMode::SNAPKV; pa_prim.sliding_window = p.sliding_window_size; pa_prim.is_key_by_channel = (p.key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL); + pa_prim.has_xattention = true; topology topology; - topology.add( - input_layout("query", query_layout), - input_layout("key", key_layout), - input_layout("value", value_layout), - input_layout("key_cache", key_cache_layout), - input_layout("value_cache", value_cache_layout), - input_layout("past_lens", past_lens_layout), - input_layout("subsequence_begins", subsequence_begins_layout), - input_layout("block_indices", block_indices_layout), - input_layout("block_indices_begins", block_indices_begins_layout), - input_layout("scale", scale_layout), - input_layout("sliding_window", sliding_window_layout), - input_layout("alibi", alibi_layout), - input_layout("max_context_len", max_context_len_layout), - input_layout("score_aggregation_window", score_aggregation_window_layout), - pa_prim, - reorder("output_data", input_info("paged_attention", 0), format::bfyx, data_types::f16) - ); - - if (p.scores_mode != ScoresMode::DISABLED) { + topology.add(input_layout("query", query_layout), + input_layout("key", key_layout), + input_layout("value", value_layout), + input_layout("key_cache", key_cache_layout), + input_layout("value_cache", value_cache_layout), + input_layout("past_lens", past_lens_layout), + input_layout("subsequence_begins", subsequence_begins_layout), + input_layout("block_indices", block_indices_layout), + input_layout("block_indices_begins", block_indices_begins_layout), + input_layout("scale", scale_layout), + input_layout("sliding_window", sliding_window_layout), + input_layout("alibi", alibi_layout), + input_layout("max_context_len", max_context_len_layout), + input_layout("score_aggregation_window", score_aggregation_window_layout), + pa_prim, + reorder("output_data", input_info("paged_attention", 0), format::bfyx, data_types::f16)); + + if (p.scores_mode != XAttentionScoresMode::DISABLED) { topology.add(reorder("output_scores", input_info("paged_attention", 1), format::bfyx, data_types::f16)); } @@ -1431,6 +1379,7 @@ struct xAttentionTest : public ::testing::TestWithParam { topology.add(input_layout("xattention_threshold", xattention_threshold_layout)); topology.add(input_layout("xattention_block_size", xattention_block_size_layout)); topology.add(input_layout("xattention_stride", xattention_stride_layout)); + topology.add(input_layout("sinks", sinks_layout)); } ExecutionConfig config = get_test_default_config(get_test_engine()); @@ -1460,6 +1409,7 @@ struct xAttentionTest : public ::testing::TestWithParam { network->set_input_data("xattention_threshold", xattention_threshold_mem); network->set_input_data("xattention_block_size", xattention_block_size_mem); network->set_input_data("xattention_stride", xattention_stride_mem); + network->set_input_data("sinks", sinks_mem); auto outputs = network->execute(); @@ -1467,7 +1417,7 @@ struct xAttentionTest : public ::testing::TestWithParam { cldnn::memory::ptr output_scores_mem = nullptr; output_data_mem = outputs.at("output_data").get_memory(); - if (p.scores_mode != ScoresMode::DISABLED) { + if (p.scores_mode != XAttentionScoresMode::DISABLED) { output_scores_mem = outputs.at("output_scores").get_memory(); } auto ref_data = xAttentionReference(pam).get_reference(); @@ -1502,7 +1452,7 @@ struct xAttentionTest : public ::testing::TestWithParam { }; struct xattention_test_params { - std::vector subsequences; + std::vector subsequences; int num_heads; int k_head_size; int v_head_size; @@ -1511,8 +1461,8 @@ struct xattention_test_params { bool kv_cache_compression; ov::internal::CacheQuantMode key_cache_quant_mode; bool dynamic_paddings; - ScoresMode scores_mode; - CacheRotationDescriptor rotation_config; + XAttentionScoresMode scores_mode; + XAttentionCacheRotationDescriptor rotation_config; bool disable_flashattn_v2; }; @@ -1525,12 +1475,12 @@ TEST_P(xattention_test, basic) { const auto ENABLE_CACHE_COMPRESSION = true; const auto DISABLE_CACHE_COMPRESSION = false; -const auto DISABLE_SCORES = ScoresMode::DISABLED; -const auto ENABLE_SCORES = ScoresMode::LAST_TOKEN; -const auto ENABLE_SCORES_SNAPKV = ScoresMode::SNAPKV; -const auto PER_BLOCK_ROTATION = CacheRotationDescriptor{true, true}; -const auto PER_TOKEN_ROTATION = CacheRotationDescriptor{true, false}; -const auto DISABLE_ROTATION = CacheRotationDescriptor{false, false}; +const auto DISABLE_SCORES = XAttentionScoresMode::DISABLED; +const auto ENABLE_SCORES = XAttentionScoresMode::LAST_TOKEN; +const auto ENABLE_SCORES_SNAPKV = XAttentionScoresMode::SNAPKV; +const auto PER_BLOCK_ROTATION = XAttentionCacheRotationDescriptor{true, true}; +const auto PER_TOKEN_ROTATION = XAttentionCacheRotationDescriptor{true, false}; +const auto DISABLE_ROTATION = XAttentionCacheRotationDescriptor{false, false}; const auto STATIC_INPUT_PAD = false; const auto DYNAMIC_INPUT_PAD = true; const auto ENABLE_FA_V2 = false; From 2dfbb19ff46dcc7892ec3111738014fb24952106 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Thu, 16 Oct 2025 15:37:04 +0800 Subject: [PATCH 74/96] Fix test cases and skip testing on unsupported platforms (#60) --- .../unit/test_cases/xattention_gpu_test.cpp | 254 ++++++++++-------- 1 file changed, 143 insertions(+), 111 deletions(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index 87e107b3c70403..1a189dc6447798 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -17,6 +17,7 @@ #include "openvino/reference/softmax.hpp" #include "openvino/reference/transpose.hpp" #include "openvino/runtime/tensor.hpp" +#include "primitive_inst.h" #include "random_generator.hpp" #include "test_utils.h" @@ -72,7 +73,7 @@ struct XAttentionManager { std::vector rotation_deltas; std::vector rotation_trig_lut; - std::vector xattention_threshold = {0.9}; + std::vector xattention_threshold; std::vector xattention_block_size; std::vector xattention_stride; @@ -94,7 +95,8 @@ struct XAttentionManager { bool kv_cache_compression, ov::internal::CacheQuantMode key_cache_quant_mode, bool has_score_aggregation, - XAttentionCacheRotationDescriptor rotation_config) + XAttentionCacheRotationDescriptor rotation_config, + std::vector threshold) : num_heads(num_heads), k_head_size(k_head_size), v_head_size(v_head_size), @@ -111,15 +113,18 @@ struct XAttentionManager { // init subsequence_begins and block_indices_begins subsequence_begins.push_back(0); block_indices_begins.push_back(0); + for (int i = 0; i < static_cast(threshold.size()); i++) { + xattention_threshold.emplace_back(static_cast(threshold[i])); + } int max_len = 0; for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { const auto& subsequence_desc = subsequence_descs[i]; max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); - query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); - key_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); - value_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); + query_data.push_back(generate_realistic_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); + key_data.push_back(generate_realistic_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); + value_data.push_back(generate_realistic_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); past_lens.push_back(subsequence_desc.past_len); int subsequence_start_pos = subsequence_begins[i]; @@ -457,6 +462,26 @@ struct XAttentionManager { return data; } + static std::vector generate_realistic_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t k_head_size) { + std::vector data(num_heads * tokens_num * k_head_size); + + std::mt19937 gen(1234); + std::normal_distribution dist(0.0f, 0.1f); + + for (size_t h = 0; h < num_heads; ++h) { + for (size_t t = 0; t < tokens_num; ++t) { + for (size_t d = 0; d < k_head_size; ++d) { + float val = dist(gen); + if (t > 0) + val = 0.8f * val + 0.2f * static_cast(data[h * tokens_num * k_head_size + (t - 1) * k_head_size + d]); + data[h * tokens_num * k_head_size + t * k_head_size + d] = static_cast(val); + } + } + } + + return data; + } + static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, @@ -849,55 +874,56 @@ class CMXAttentionBlockSelector { }; struct xAttentionReference { - xAttentionReference(XAttentionManager& pam) : pam(pam), test_engine(pam.test_engine), test_stream(pam.test_stream) {} + xAttentionReference(XAttentionManager& xam) : xam(xam), test_engine(xam.test_engine), test_stream(xam.test_stream) {} - std::pair, std::vector> get_reference() { + std::pair, std::vector> get_reference(std::vector threshold) { std::vector ref_data_output; std::vector ref_scores_output; - for (size_t i = 0; i < pam.subsequence_descs.size(); i++) { - const auto& subsequence_desc = pam.subsequence_descs[i]; + for (size_t i = 0; i < xam.subsequence_descs.size(); i++) { + const auto& subsequence_desc = xam.subsequence_descs[i]; const auto kv_seq_len = subsequence_desc.num_tokens + subsequence_desc.past_len; - auto key_data = pam.key_data[i]; - if (pam.rotation_config.apply_rotation) { - auto blocks_start = pam.block_indices_begins[i]; - auto blocks_end = pam.block_indices_begins[i + 1]; + auto key_data = xam.key_data[i]; + if (xam.rotation_config.apply_rotation) { + auto blocks_start = xam.block_indices_begins[i]; + auto blocks_end = xam.block_indices_begins[i + 1]; - std::vector block_indices(pam.block_indices.begin() + blocks_start, pam.block_indices.begin() + blocks_end); + std::vector block_indices(xam.block_indices.begin() + blocks_start, xam.block_indices.begin() + blocks_end); for (const auto& block_idx : block_indices) { - auto it = std::find(pam.rotated_block_indices.begin(), pam.rotated_block_indices.end(), block_idx); - if (it != pam.rotated_block_indices.end()) { - int index = std::distance(pam.rotated_block_indices.begin(), it); + auto it = std::find(xam.rotated_block_indices.begin(), xam.rotated_block_indices.end(), block_idx); + if (it != xam.rotated_block_indices.end()) { + int index = std::distance(xam.rotated_block_indices.begin(), it); int subsequence_rotated_block_idx = *it - blocks_start; rotate_block(key_data, - pam.rotation_deltas, - pam.rotation_trig_lut, + xam.rotation_deltas, + xam.rotation_trig_lut, index, subsequence_rotated_block_idx, - pam.num_heads, - pam.k_head_size, - pam.block_size, - pam.rotation_config.per_block); + xam.num_heads, + xam.k_head_size, + xam.block_size, + xam.rotation_config.per_block); } } } - auto window_size = pam.has_score_aggregation ? pam.score_aggregation[i] : 1; + auto window_size = xam.has_score_aggregation ? xam.score_aggregation[i] : 1; - auto subsequence_ref_results = run_reference(pam.query_data[i], + auto subsequence_ref_results = run_reference(xam.query_data[i], key_data, - pam.value_data[i], + xam.value_data[i], subsequence_desc.num_tokens, kv_seq_len, - pam.num_heads, - pam.k_head_size, - pam.v_head_size, + xam.num_heads, + xam.k_head_size, + xam.v_head_size, window_size, - pam.sliding_window_size, - pam.get_default_scale()); + xam.sliding_window_size, + xam.get_default_scale(), + static_cast(threshold[i])); // concatenate all subsequences into one vector ref_data_output.insert(ref_data_output.end(), subsequence_ref_results.first.begin(), subsequence_ref_results.first.end()); @@ -922,13 +948,9 @@ struct xAttentionReference { double threshold = 0.9, size_t block_size = 128, size_t stride = 16) { - auto query_shape_bfyx = ov::PartialShape{1, num_queries, num_heads, k_head_size}; - auto key_shape_bfyx = ov::PartialShape{1, num_keys, num_heads, k_head_size}; - auto value_shape_bfyx = ov::PartialShape{1, num_keys, num_heads, v_head_size}; - - auto query_layout = layout{query_shape_bfyx, data_types::f16, format::bfyx}; - auto key_layout = layout{key_shape_bfyx, data_types::f16, format::bfyx}; - auto value_layout = layout{value_shape_bfyx, data_types::f16, format::bfyx}; + auto query_layout = layout{{1, num_queries, num_heads, k_head_size}, data_types::f16, format::bfyx}; + auto key_layout = layout{{1, num_keys, num_heads, k_head_size}, data_types::f16, format::bfyx}; + auto value_layout = layout{{1, num_keys, num_heads, v_head_size}, data_types::f16, format::bfyx}; OPENVINO_ASSERT(query_layout.count() == query_data.size()); OPENVINO_ASSERT(key_layout.count() == key_data.size()); @@ -942,52 +964,47 @@ struct xAttentionReference { set_values(key_mem, key_data); set_values(value_mem, value_data); - std::vector query_data_3d(num_heads * num_queries * k_head_size); - std::vector key_data_3d(num_heads * num_keys * k_head_size); - - for (int h = 0; h < num_heads; h++) { - for (int q = 0; q < num_queries; q++) { - for (int d = 0; d < k_head_size; d++) { - query_data_3d[h * num_queries * k_head_size + q * k_head_size + d] = query_data[q * num_heads * k_head_size + h * k_head_size + d]; - } - } - } - - for (int h = 0; h < num_heads; h++) { - for (int k = 0; k < num_keys; k++) { - for (int d = 0; d < k_head_size; d++) { - key_data_3d[h * num_keys * k_head_size + k * k_head_size + d] = key_data[k * num_heads * k_head_size + h * k_head_size + d]; + auto reorder_qhk_to_hqd = [&](const std::vector& src, int outer_len, int num_heads, int head_dim) { + std::vector dst(num_heads * outer_len * head_dim); + for (int h = 0; h < num_heads; ++h) { + size_t dst_h_off = static_cast(h) * outer_len * head_dim; + for (int i = 0; i < outer_len; ++i) { + size_t src_off = static_cast(i) * num_heads * head_dim + static_cast(h) * head_dim; + std::copy_n(&src[src_off], head_dim, &dst[dst_h_off + static_cast(i) * head_dim]); } } - } + return dst; + }; - ov::Shape query_shape_3d = {static_cast(num_heads), static_cast(num_queries), static_cast(k_head_size)}; - ov::Shape key_shape_3d = {static_cast(num_heads), static_cast(num_keys), static_cast(k_head_size)}; + const auto query_data_3d = reorder_qhk_to_hqd(query_data, num_queries, num_heads, k_head_size); + const auto key_data_3d = reorder_qhk_to_hqd(key_data, num_keys, num_heads, k_head_size); CMXAttentionRetainedBlockIndicesForAllHeads retained_blocks; if (num_queries >= static_cast(block_size)) { - size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; - size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; - std::vector query_padded(num_heads * padded_q * k_head_size, ov::float16(0)); - std::vector key_padded(num_heads * padded_k * k_head_size, ov::float16(0)); - - for (int h = 0; h < num_heads; ++h) { - std::copy_n(&query_data_3d[h * num_queries * k_head_size], num_queries * k_head_size, &query_padded[h * padded_q * k_head_size]); - std::copy_n(&key_data_3d[h * num_keys * k_head_size], num_keys * k_head_size, &key_padded[h * padded_k * k_head_size]); - } + const size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; + const size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; - ov::Shape query_shape_padded = {static_cast(num_heads), padded_q, static_cast(k_head_size)}; - ov::Shape key_shape_padded = {static_cast(num_heads), padded_k, static_cast(k_head_size)}; + std::vector query_padded(num_heads * padded_q * k_head_size, 0.f); + std::vector key_padded(num_heads * padded_k * k_head_size, 0.f); - std::vector query_padded_f32(query_padded.size()); - std::vector key_padded_f32(key_padded.size()); - for (size_t i = 0; i < query_padded.size(); ++i) - query_padded_f32[i] = static_cast(query_padded[i]); - for (size_t i = 0; i < key_padded.size(); ++i) - key_padded_f32[i] = static_cast(key_padded[i]); + for (int h = 0; h < num_heads; ++h) { + const auto* q_src = &query_data_3d[h * num_queries * k_head_size]; + const auto* k_src = &key_data_3d[h * num_keys * k_head_size]; + auto* q_dst = &query_padded[h * padded_q * k_head_size]; + auto* k_dst = &key_padded[h * padded_k * k_head_size]; + std::transform(q_src, q_src + num_queries * k_head_size, q_dst, [](ov::float16 v) { + return static_cast(v); + }); + std::transform(k_src, k_src + num_keys * k_head_size, k_dst, [](ov::float16 v) { + return static_cast(v); + }); + } CMXAttentionBlockSelector selector(threshold, block_size, stride); - retained_blocks = selector.select_blocks(query_padded_f32.data(), query_shape_padded, key_padded_f32.data(), key_shape_padded); + retained_blocks = selector.select_blocks(query_padded.data(), + {static_cast(num_heads), padded_q, static_cast(k_head_size)}, + key_padded.data(), + {static_cast(num_heads), padded_k, static_cast(k_head_size)}); } auto mask_mem = get_mask_mem_combined_multi_head(num_queries, num_keys, num_heads, sliding_window_size, retained_blocks, static_cast(block_size)); @@ -1174,7 +1191,7 @@ struct xAttentionReference { } } - XAttentionManager& pam; + XAttentionManager& xam; cldnn::engine& test_engine; cldnn::stream& test_stream; }; @@ -1191,7 +1208,7 @@ struct xAttentionTest : public ::testing::TestWithParam { } void execute(T& p) { - XAttentionManager pam(rg, + XAttentionManager xam(rg, get_test_engine(), get_test_stream(), p.subsequences, @@ -1203,40 +1220,41 @@ struct xAttentionTest : public ::testing::TestWithParam { p.kv_cache_compression, p.key_cache_quant_mode, p.scores_mode == XAttentionScoresMode::SNAPKV, - p.rotation_config); + p.rotation_config, + p.threshold); if (p.kv_cache_compression) tolerance = 25e-3; - auto query_mem = pam.get_query_memory(); - auto key_mem = pam.get_key_memory(); - auto value_mem = pam.get_value_memory(); + auto query_mem = xam.get_query_memory(); + auto key_mem = xam.get_key_memory(); + auto value_mem = xam.get_value_memory(); - auto key_cache_mem = pam.get_key_cache_memory(); - auto value_cache_mem = pam.get_value_cache_memory(); + auto key_cache_mem = xam.get_key_cache_memory(); + auto value_cache_mem = xam.get_value_cache_memory(); - auto past_lens_mem = pam.get_past_lens_memory(); - auto subsequence_begins_mem = pam.get_subsequence_begins_memory(); - auto block_indices_mem = pam.get_block_indices_memory(); - auto block_indices_begins_mem = pam.get_block_indices_begins_memory(); + auto past_lens_mem = xam.get_past_lens_memory(); + auto subsequence_begins_mem = xam.get_subsequence_begins_memory(); + auto block_indices_mem = xam.get_block_indices_memory(); + auto block_indices_begins_mem = xam.get_block_indices_begins_memory(); - auto scale_mem = pam.get_scale_memory(); - auto sliding_window_mem = pam.get_sliding_window_memory(); - auto alibi_mem = pam.get_alibi_memory(); - auto max_context_len_mem = pam.get_max_context_len_memory(); + auto scale_mem = xam.get_scale_memory(); + auto sliding_window_mem = xam.get_sliding_window_memory(); + auto alibi_mem = xam.get_alibi_memory(); + auto max_context_len_mem = xam.get_max_context_len_memory(); // scores calculation related memory buffers - auto score_aggregation_mem = pam.get_score_aggregation(); + auto score_aggregation_mem = xam.get_score_aggregation(); // cache rotation related memory buffers - auto rotated_block_indices_mem = pam.get_rotated_block_indices_memory(); - auto rotation_deltas_mem = pam.get_rotation_deltas_memory(); - auto rotation_trig_lut_mem = pam.get_rotation_trig_lut_memory(); + auto rotated_block_indices_mem = xam.get_rotated_block_indices_memory(); + auto rotation_deltas_mem = xam.get_rotation_deltas_memory(); + auto rotation_trig_lut_mem = xam.get_rotation_trig_lut_memory(); - auto xattention_threshold_mem = pam.get_xattention_threshold_memory(); - auto xattention_block_size_mem = pam.get_xattention_block_size_memory(); - auto xattention_stride_mem = pam.get_xattention_stride_memory(); - auto sinks_mem = pam.get_sinks_memory(); + auto xattention_threshold_mem = xam.get_xattention_threshold_memory(); + auto xattention_block_size_mem = xam.get_xattention_block_size_memory(); + auto xattention_stride_mem = xam.get_xattention_stride_memory(); + auto sinks_mem = xam.get_sinks_memory(); auto query_layout = query_mem->get_layout(); auto key_layout = key_mem->get_layout(); @@ -1336,7 +1354,7 @@ struct xAttentionTest : public ::testing::TestWithParam { pa_prim.v_head_size = p.v_head_size; pa_prim.kv_heads_num = p.num_heads; pa_prim.heads_num = p.num_heads; - pa_prim.scale_val = pam.get_default_scale(); + pa_prim.scale_val = xam.get_default_scale(); pa_prim.has_alibi = false; pa_prim.num_outputs = p.scores_mode == XAttentionScoresMode::DISABLED ? 1 : 2; pa_prim.has_rotated_blocks = p.rotation_config.apply_rotation; @@ -1420,7 +1438,7 @@ struct xAttentionTest : public ::testing::TestWithParam { if (p.scores_mode != XAttentionScoresMode::DISABLED) { output_scores_mem = outputs.at("output_scores").get_memory(); } - auto ref_data = xAttentionReference(pam).get_reference(); + auto ref_data = xAttentionReference(xam).get_reference(p.threshold); compare(output_data_mem, output_scores_mem, ref_data); } @@ -1449,6 +1467,16 @@ struct xAttentionTest : public ::testing::TestWithParam { EXPECT_LE(mismatch_count, int(scores_output_mem->count() * 0.04)); } } + + static bool check_xattention_available() { + auto& engine = get_test_engine(); + ExecutionConfig config = get_test_default_config(engine); + if (!cldnn::check_cm_jit_support(engine, config) || !engine.get_device_info().supports_immad) { + return false; + } + + return true; + } }; struct xattention_test_params { @@ -1457,6 +1485,7 @@ struct xattention_test_params { int k_head_size; int v_head_size; int block_size; + std::vector threshold; int sliding_window_size; bool kv_cache_compression; ov::internal::CacheQuantMode key_cache_quant_mode; @@ -1468,6 +1497,8 @@ struct xattention_test_params { class xattention_test : public xAttentionTest {}; TEST_P(xattention_test, basic) { + if (!check_xattention_available()) + GTEST_SKIP(); auto p = GetParam(); execute(p); @@ -1491,15 +1522,16 @@ INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, ::testing::ValuesIn(std::vector{ /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, token_size>=32, disable_mix_mode */ - xattention_test_params{ {{32, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - xattention_test_params{ {{4096, 0}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - - xattention_test_params{ {{1, 31}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 32}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 1023}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 1024}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 127}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 128}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 129}}, 2, 64, 64, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 32}}, 28, 128, 128, 256, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{32, 0}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + xattention_test_params{ {{1024, 0}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + xattention_test_params{ {{2048, 0}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + + xattention_test_params{ {{1, 31}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 32}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 1023}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 1024}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 127}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 128}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 129}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 32}}, 28, 128, 128, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token })); From cca1528273e2ef6bad48fdd8d262c85cfe5da5d9 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Thu, 16 Oct 2025 17:05:11 +0800 Subject: [PATCH 75/96] bypas xattn when thresh>=1.0 and q_lenstage == PagedAttentionStage::PREFILL || rt_params->stage == PagedAttentionStage::MIXED) { - const float xattn_thresh = get_xattn_thresh(params); - const bool validate = xattn_thresh < 1.0; - if (has_stage(xattn_estimate_gemmqk) && validate) { // bypass xattn stages if threshold is larger than 1.0. + if (has_stage(xattn_estimate_gemmqk) && !bypass_xattn(params)) { res_event = {execute_stage(res_event, instance, xattn_estimate_gemmqk)}; res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; res_event = {execute_stage(res_event, instance, xattn_estimate_post_proc)}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 954b6db40e8b38..a415a1e412bd05 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -115,7 +115,7 @@ size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx) { // TODO: change xattn_thresh from scaler to memory... once we remove the converter node // between parameter node "xattention_threshold.xxx" and paged_attention node. -const float get_xattn_thresh(const kernel_impl_params& params, const size_t seq_idx) { +float get_xattn_thresh(const kernel_impl_params& params, const size_t seq_idx) { const auto& input_mem = params.memory_deps; const auto threshold_mem = input_mem.at(PagedAttentionInputIdx::XATTENTION_THRESHOLD); mem_lock lock(threshold_mem, *params.strm); // converted @@ -123,6 +123,18 @@ const float get_xattn_thresh(const kernel_impl_params& params, const size_t seq_ return thresh; } + // Bypass xattn stages in the following conditions - + // either threshold is larger than 1.0, or, q_len is too small + // to compute xattn block_mask. +bool bypass_xattn(const kernel_impl_params& params) { + auto xattn_thresh = get_xattn_thresh(params); + bool bypass = xattn_thresh >= 1.0; + + auto q_len = params.output_layouts[0].get_shape()[0]; + bypass |= q_len < static_cast(STRIDE); //# will slient drop the tails which is less than `stride` + return bypass; +} + PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) { const auto& query_shape = impl_param.get_input_layout(PagedAttentionInputIdx::QUERY).get_partial_shape(); const auto& past_lens_shape = impl_param.get_input_layout(PagedAttentionInputIdx::PAST_LENS).get_partial_shape(); @@ -383,8 +395,7 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con scalars[2].v.s32 = static_cast(rtp->xattn_k_block_pad); scalars[3].t = ScalarDescriptor::Types::UINT8; - const float xattn_thresh = get_xattn_thresh(params); - const bool validate = xattn_thresh < 1.0; + const bool validate = !bypass_xattn(params); scalars[3].v.u8 = static_cast(validate); // validate depending on xattn_threshold } }}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 337c974dbe7156..2c81c58351484e 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -64,8 +64,9 @@ size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx); size_t get_partition_size(const bool has_xattention); size_t get_partition_num(const size_t kv_len, const bool has_xattention); -const float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx = 0); -inline const size_t get_xattn_block_size(const kernel_impl_params& impl_param) { +float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx = 0); +bool bypass_xattn(const kernel_impl_params& impl_param); +inline size_t get_xattn_block_size(const kernel_impl_params& impl_param) { return XATTN_BLOCK_SIZE; } From 618e575c62b8405de8bac338556beb438e5b6955 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Thu, 16 Oct 2025 17:12:27 +0800 Subject: [PATCH 76/96] throw exception if xattn is not supported by either GPU archieture or compiler. --- .../src/plugin/transformations_pipeline.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index a55c57df48204c..f75acd67f99902 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -519,15 +519,15 @@ void TransformationsPipeline::apply(std::shared_ptr func) { #ifdef GPU_DEBUG_CONFIG if (!config.get_use_cm()) { - OPENVINO_WARN("You may miss SDPAToVLSDPA optimization for QWenVL model," + OPENVINO_WARN("You may miss XAttention optimization," "as CM for usage is disabled. Enable it by setting environment variable OV_GPU_USE_CM=ON."); return false; } #endif if (!check_cm_jit_support(engine, config)) { - OPENVINO_WARN("You may miss SDPAToVLSDPA optimization for QWenVL model," - "as current IGC version is not compatible to the CM kernel used. Enable it by update IGC." + OPENVINO_WARN("You may miss XAttention optimization," + "as current IGC version is not compatible to the CM kernel used. Enable it by updating IGC." "Please also make sure clangFEWrapper for CM is present by checking environment varibles like " "CM_FE_DIR or LD_LIBRARY_PATH if you are using Linux."); return false; @@ -547,9 +547,10 @@ void TransformationsPipeline::apply(std::shared_ptr func) { } } - // Fallback to dense attention if xattn is not supported by either GPU archieture or compiler. - if (use_xattention) - use_xattention = check_xattn_gpu_compatibility(); + if (use_xattention) { + // Throw exception if xattn is not supported by either GPU archieture or compiler. + OPENVINO_ASSERT(check_xattn_gpu_compatibility(), "XAttention is not supported by either GPU archieture or IGC you are using."); + } // KVCache layout with default attention - // k: [num_blocks, num_kv_heads, head_size, block_size(16)] From 522a5031c423f6f1512bceb579847282d3b7be11 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 17 Oct 2025 13:59:26 +0800 Subject: [PATCH 77/96] add OV_GPU_DUMP_SRC_TENSORS_AFTER_EXEC --- .../intel_gpu/runtime/internal_properties.hpp | 1 + .../include/intel_gpu/runtime/options.inl | 1 + .../intel_gpu/src/graph/debug_helper.cpp | 46 ++++++++++--------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp index 42e6bf74949990..eb55e68ef6a609 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp @@ -176,6 +176,7 @@ static constexpr Property could_use_flashattn_ static constexpr Property dynamic_quantization_group_size_max{"GPU_DYNAMIC_QUANTIZATION_GROUP_SIZE_MAX"}; static constexpr Property validate_output_buffer{"GPU_VALIDATE_OUTPUT_BUFFER"}; static constexpr Property mem_pool_util_threshold{"GPU_MEM_POOL_UTIL_THRESHOLD"}; +static constexpr Property dump_src_after_exec{"GPU_DUMP_SRC_TENSORS_AFTER_EXEC"}; } // namespace ov::intel_gpu namespace cldnn { diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl index 2126e21c509a6e..be2bf896326e62 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl @@ -81,6 +81,7 @@ OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_layer_names, std::vector OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_memory_pool_path, "", "Save csv file with memory pool info to specified folder") OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_memory_pool, false, "Enable verbose output for memory pool") OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_iterations, std::set{}, "Space separated list of iterations where other dump options should be enabled") +OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_src_after_exec, false, "Enable source data dump after layer execution. Useful for capturing updated states in stateful models.") OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, host_time_profiling, 0, "Measure and print host time spent from the beginning of the infer until all host work is done and plugin is ready to block thread on the final clFinish() call") OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_async_compilation, false, "Disable feature that allows to asynchronously prepare static-shaped implementations for the primitives with shape-agnostic kernels selected during compilation") OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_runtime_buffer_fusing, false, "Disable runtime inplace optimizations for operations like concat and crop") diff --git a/src/plugins/intel_gpu/src/graph/debug_helper.cpp b/src/plugins/intel_gpu/src/graph/debug_helper.cpp index c619c2894bfa99..145c1d5b8065a4 100644 --- a/src/plugins/intel_gpu/src/graph/debug_helper.cpp +++ b/src/plugins/intel_gpu/src/graph/debug_helper.cpp @@ -559,29 +559,31 @@ NodeDebugHelper::~NodeDebugHelper() { } } - for (size_t i = 0; i < m_inst.inputs_memory_count(); i++) { - std::string name = get_file_prefix() + "_updated_src_" + std::to_string(i); - auto output_mem = m_inst.input_memory_ptr(i); - if (output_mem == nullptr) { - GPU_DEBUG_COUT << " updated_input_mem is nullptr. Nothing to dump." << std::endl; - continue; - } - - auto& output_layout = m_inst.get_input_layout(i); - if (config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::binary) { - // Binary dump : raw - auto filename = get_file_path_for_binary_dump(output_layout, name, config.get_dump_tensors_path()); + if (config.get_dump_src_after_exec()) { + for (size_t i = 0; i < m_inst.inputs_memory_count(); i++) { + std::string name = get_file_prefix() + "_updated_src_" + std::to_string(i); + auto output_mem = m_inst.input_memory_ptr(i); + if (output_mem == nullptr) { + GPU_DEBUG_COUT << " updated_input_mem is nullptr. Nothing to dump." << std::endl; + continue; + } - mem_lock lock(output_mem, m_stream); - ov::util::save_binary(filename, lock.data(), output_mem->size()); - GPU_DEBUG_COUT << " Dump layer dst : " << layer_name << " to " << filename << std::endl; - debug_str_for_bin_load += (filename + ","); - } else { - const bool dump_raw = config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::text_raw; - GPU_DEBUG_COUT << " Dump " << (dump_raw ? "raw " : "") << name << std::endl; - auto filename = config.get_dump_tensors_path() + get_name_for_dump(name) + ".txt"; - // Text dump - log_memory_to_file(output_mem, output_layout, m_stream, filename, dump_raw); + auto& output_layout = m_inst.get_input_layout(i); + if (config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::binary) { + // Binary dump : raw + auto filename = get_file_path_for_binary_dump(output_layout, name, config.get_dump_tensors_path()); + + mem_lock lock(output_mem, m_stream); + ov::util::save_binary(filename, lock.data(), output_mem->size()); + GPU_DEBUG_COUT << " Dump layer dst : " << layer_name << " to " << filename << std::endl; + debug_str_for_bin_load += (filename + ","); + } else { + const bool dump_raw = config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::text_raw; + GPU_DEBUG_COUT << " Dump " << (dump_raw ? "raw " : "") << name << std::endl; + auto filename = config.get_dump_tensors_path() + get_name_for_dump(name) + ".txt"; + // Text dump + log_memory_to_file(output_mem, output_layout, m_stream, filename, dump_raw); + } } } From 3e527be3bcde29b3830f71447c39ff759df2c915 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 17 Oct 2025 14:01:04 +0800 Subject: [PATCH 78/96] code cleanup, unused code --- .../intel_gpu/src/graph/impls/cm/include/sort.hpp | 11 ----------- .../intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm | 6 ------ src/plugins/intel_gpu/src/runtime/memory.cpp | 6 +++--- .../unit/test_cases/paged_attention_gpu_test.cpp | 3 --- 4 files changed, 3 insertions(+), 23 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp index f05396fac810d9..97f680e5cba441 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp @@ -144,17 +144,6 @@ CM_INLINE void sort(uint slm, svmptr_t src, svmptr_t sorted_value, svmptr_t sort cm_slm_write(slm, addr, total); } } - // { - // // prefix sum - // vector data; - // cm_slm_block_read(slm, 0, data); - // for (int i = 1; i < 16 * THREADS; i++) { - // data[i] += data[i - 1]; - // } - // data.select<16 * THREADS - 1, 1>(1) = data.select<16 * THREADS - 1, 1>(0); - // data[0] = 0; - // cm_slm_block_write(slm, 0, data); - // } { // prefix sum vector local_prefix = 0; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm index 4e1e305944983f..e3f23fd9ec65e2 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm @@ -84,12 +84,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint id_wg_m, id_wg_n; get_mn(id_wg_m, id_wg_n, M, N, slice_no, slice, BLOCK_WG_M, BLOCK_WG_N); - auto wg_id_N = cm_group_id(0); - auto wg_lid_N = cm_local_id(0); - auto wg_id_M = cm_group_id(1); - auto wg_lid_M = cm_local_id(1); - // printf("=============================================== wgN:%d.%d wgM:%d.%d hq %d =============================================== \n", wg_id_N, wg_lid_N, wg_id_M, wg_lid_M, hq); - // key cache: [block, HQ, KV_BLOCK_SIZE, HEAD_SIZE_KEY] #if USE_INT8 key_cache += hk * (KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); diff --git a/src/plugins/intel_gpu/src/runtime/memory.cpp b/src/plugins/intel_gpu/src/runtime/memory.cpp index c43526fd6be184..f69a3124da7d6d 100644 --- a/src/plugins/intel_gpu/src/runtime/memory.cpp +++ b/src/plugins/intel_gpu/src/runtime/memory.cpp @@ -35,9 +35,9 @@ MemoryTracker::~MemoryTracker() { try { m_engine->subtract_memory_used(m_buffer_size, m_alloc_type); } catch (...) {} - // GPU_DEBUG_TRACE_DETAIL << "Free " << m_buffer_size << " bytes of " << m_alloc_type << " allocation type ptr = " << m_buffer_ptr - // << " (current=" << m_engine->get_used_device_memory(m_alloc_type) << ";" - // << " max=" << m_engine->get_max_used_device_memory(m_alloc_type) << ")" << std::endl; + GPU_DEBUG_TRACE_DETAIL << "Free " << m_buffer_size << " bytes of " << m_alloc_type << " allocation type ptr = " << m_buffer_ptr + << " (current=" << m_engine->get_used_device_memory(m_alloc_type) << ";" + << " max=" << m_engine->get_max_used_device_memory(m_alloc_type) << ")" << std::endl; } } diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 362afd4ff95c3a..f01d8010969005 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1264,9 +1264,6 @@ struct PagedAttentionTest : public ::testing::TestWithParam { if (data_output_mem) { ASSERT_EQ(data_output_mem->count(), ref_data.first.size()); mem_lock mem_ptr(data_output_mem, get_test_stream()); - // for (size_t i = 0; i < data_output_mem->count(); i++) { - // std::cout << i << ": result = " << mem_ptr[i] << ", reference = " << ref_data.first[i] << std::endl; - // } for (size_t i = 0; i < data_output_mem->count(); i++) { ASSERT_NEAR(mem_ptr[i], ref_data.first[i], tolerance) << " at index=" << i; } From 1e243fc0a96fc2fb7283c296579d77308749372c Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 17 Oct 2025 14:01:27 +0800 Subject: [PATCH 79/96] throw exception for unsupported cases. --- .../src/graph/impls/cm/paged_attention.cpp | 1 - .../graph/impls/cm/paged_attention_gen.cpp | 4 +-- .../graph/impls/cm/paged_attention_gen.hpp | 6 ++-- .../src/plugin/transformations_pipeline.cpp | 34 +++++++++++-------- 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 1e5142fdd3512b..50ec344e598ccd 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -197,7 +197,6 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto count_elements_mask = static_cast(desc->heads_num * q_block_pad * k_block_pad); internal_buffers.emplace_back(count_elements_mask, ov::element::boolean); // 4: sparse_block_mask - const uint32_t MERGED_Q_NUM = 2; // TODO const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); auto count_elements_mask_merged = static_cast(desc->heads_num * q_block_pad_merged * k_block_pad); internal_buffers.emplace_back(count_elements_mask_merged, ov::element::boolean); // 5: sparse_block_mask_wg diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index a415a1e412bd05..ea45c3e61a77e6 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -825,7 +825,7 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { JitConstants XAttentionEstimatePostProc::get_jit_constants(const kernel_impl_params& params) const { auto jit = XAttentionEstimateGeneratorBase::get_jit_constants(params); - jit.make("MERGED_Q_NUM", 2); // TODO + jit.make("MERGED_Q_NUM", MERGED_Q_NUM); return jit; } @@ -871,8 +871,6 @@ DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const { const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); const uint32_t k_block_pad = k_block_in_group * N_kq_groups; const uint32_t q_block_pad = ceil_div(q_len, block_size); - - const uint32_t MERGED_Q_NUM = 2; // TODO const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); wgs.global = {q_block_pad_merged, heads_num, 1}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 2c81c58351484e..3581fc9d5b03a7 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -23,9 +23,6 @@ using namespace cldnn; // TODO: Remove once namespaces are aligned namespace ov::intel_gpu::cm { -// constexpr auto get_pa_build_options() { -// return " -cmc -Qxcm_register_file_size=256 -mdump_asm -g2 "; -// } constexpr auto get_pa_build_options() { return " -cmc -Qxcm_register_file_size=256"; } @@ -41,7 +38,8 @@ constexpr uint32_t SG_N = 8; constexpr uint32_t BLOCK_WG_M = BLOCK_SG_M * SG_M; constexpr uint32_t BLOCK_WG_N = BLOCK_SG_N * SG_N; constexpr int STRIDE = 16; -constexpr size_t XATTN_BLOCK_SIZE = 128; +constexpr uint32_t XATTN_BLOCK_SIZE = 128; +constexpr uint32_t MERGED_Q_NUM = PA_KV_CACHE_BLOCK_SIZE_XATTN / XATTN_BLOCK_SIZE; // for xattn post_proc enum class PagedAttentionStage : uint8_t { GENERATE = 0, PREFILL = 1, MIXED = 2, UNKNOWN = 3 }; struct PagedAttentionRuntimeParams : public ImplRuntimeParams { diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index ed56389f0135c3..29d3d9105ba801 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -365,17 +365,15 @@ void TransformationsPipeline::apply(std::shared_ptr func) { #ifdef GPU_DEBUG_CONFIG if (!config.get_use_cm()) { - OPENVINO_WARN("You may miss SDPAToVLSDPA optimization for QWenVL model," - "as CM for usage is disabled. Enable it by setting environment variable OV_GPU_USE_CM=ON."); + OPENVINO_WARN("XAttention optimization is disabled because CM is not enabled. " + "To enable, set environment variable OV_GPU_USE_CM=ON."); return true; } #endif if (!check_cm_jit_support(engine, config)) { - OPENVINO_WARN("You may miss SDPAToVLSDPA optimization for QWenVL model," - "as current IGC version is not compatible to the CM kernel used. Enable it by update IGC." - "Please also make sure clangFEWrapper for CM is present by checking environment varibles like " - "CM_FE_DIR or LD_LIBRARY_PATH if you are using Linux."); + OPENVINO_WARN("SDPAToVLSDPA optimization for QWenVL model unavailable: IGC version incompatible with CM kernel. " + "Update IGC and ensure clangFEWrapper for CM is available (check CM_FE_DIR or LD_LIBRARY_PATH on Linux)."); return true; } @@ -519,25 +517,25 @@ void TransformationsPipeline::apply(std::shared_ptr func) { #ifdef GPU_DEBUG_CONFIG if (!config.get_use_cm()) { - OPENVINO_WARN("You may miss XAttention optimization," - "as CM for usage is disabled. Enable it by setting environment variable OV_GPU_USE_CM=ON."); + OPENVINO_WARN("XAttention optimization is disabled because CM is not enabled. " + "To enable, set environment variable OV_GPU_USE_CM=ON."); return false; } #endif if (!check_cm_jit_support(engine, config)) { - OPENVINO_WARN("You may miss XAttention optimization," - "as current IGC version is not compatible to the CM kernel used. Enable it by updating IGC." - "Please also make sure clangFEWrapper for CM is present by checking environment varibles like " - "CM_FE_DIR or LD_LIBRARY_PATH if you are using Linux."); + OPENVINO_WARN("XAttention optimization unavailable: IGC version incompatible with CM kernel. " + "Update IGC and ensure clangFEWrapper for CM is available (check CM_FE_DIR or LD_LIBRARY_PATH on Linux)."); return false; } return true; }; - // Determine if XAttention is enabled by user (via GENAI) by checking if model parameters contains - // xattention configurations, which are added in SDPAToPagedAttention pass. + + // Check if XAttention is enabled by the user via GENAI. + // This is determined by inspecting the model parameters for XAttention configurations, + // which are introduced during the SDPAToPagedAttention pass. bool use_xattention = false; const auto& parameters = func->get_parameters(); for (const auto& param : parameters) { @@ -549,7 +547,10 @@ void TransformationsPipeline::apply(std::shared_ptr func) { if (use_xattention) { // Throw exception if xattn is not supported by either GPU archieture or compiler. - OPENVINO_ASSERT(check_xattn_gpu_compatibility(), "XAttention is not supported by either GPU archieture or IGC you are using."); + if (!check_xattn_gpu_compatibility()) + OPENVINO_THROW("XAttention is not supported by your current GPU architecture or IGC version. " + "Please either disable XAttention by following the GenAI guide, or switch to a GPU with Xe2/Xe3 " + "architecture and ensure the latest IGC is installed."); } // KVCache layout with default attention - @@ -574,6 +575,9 @@ void TransformationsPipeline::apply(std::shared_ptr func) { kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; if (use_xattention) { + if (kv_cache_config.keyCacheQuantBychannel) + OPENVINO_THROW("XAttention does not currently support per-channel quantized key cache."); + kv_cache_config.valueCacheBlockSize = cldnn::paged_attention::block_size_xattn; kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; } else { From b45062c2e225c3c565da714a50b37f89a7ae73ef Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 17 Oct 2025 14:13:05 +0800 Subject: [PATCH 80/96] fix dump... intermediates tensor may empty. --- src/plugins/intel_gpu/src/graph/debug_helper.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/src/graph/debug_helper.cpp b/src/plugins/intel_gpu/src/graph/debug_helper.cpp index 145c1d5b8065a4..008fae45309ff8 100644 --- a/src/plugins/intel_gpu/src/graph/debug_helper.cpp +++ b/src/plugins/intel_gpu/src/graph/debug_helper.cpp @@ -536,7 +536,7 @@ NodeDebugHelper::~NodeDebugHelper() { for (size_t i = 0; i < m_inst.get_intermediates_memories().size(); i++) { std::string name = get_file_prefix() + "_intermediates_" + std::to_string(i); auto output_mem = m_inst.get_intermediates_memories()[i]; - if (output_mem == nullptr) { + if (output_mem == nullptr || output_mem->size() == 0) { GPU_DEBUG_COUT << " intermediates_mem is nullptr. Nothing to dump." << std::endl; continue; } From 50628c555090e4501cdf39a2eb83c452c3bc0ed5 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 17 Oct 2025 14:47:40 +0800 Subject: [PATCH 81/96] fix --- .../intel_gpu/src/plugin/transformations_pipeline.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 29d3d9105ba801..fcfb210d35cc48 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -548,7 +548,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { if (use_xattention) { // Throw exception if xattn is not supported by either GPU archieture or compiler. if (!check_xattn_gpu_compatibility()) - OPENVINO_THROW("XAttention is not supported by your current GPU architecture or IGC version. " + OPENVINO_THROW("[GPU] XAttention is not supported by your current GPU architecture or IGC version. " "Please either disable XAttention by following the GenAI guide, or switch to a GPU with Xe2/Xe3 " "architecture and ensure the latest IGC is installed."); } @@ -575,8 +575,9 @@ void TransformationsPipeline::apply(std::shared_ptr func) { kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; if (use_xattention) { - if (kv_cache_config.keyCacheQuantBychannel) - OPENVINO_THROW("XAttention does not currently support per-channel quantized key cache."); + if (kv_cache_config.keyCacheQuantBychannel && + ((kv_cache_precision == ov::element::i8 || kv_cache_precision == ov::element::u8)) ) + OPENVINO_THROW("[GPU] XAttention does not currently support per-channel quantized key cache."); kv_cache_config.valueCacheBlockSize = cldnn::paged_attention::block_size_xattn; kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; From 10730028452ef19fe236d9da3700e903d7a5752e Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Sun, 19 Oct 2025 22:15:13 +0800 Subject: [PATCH 82/96] Ww/pa cm xattention 1019 (#61) * Tests support num_kv_heads * Update test cases * Fix code style * Fix code style --- .../impls/ocl_v2/sdpa/paged_attention_opt.hpp | 2 +- .../test_cases/paged_attention_gpu_test.cpp | 29 +- .../unit/test_cases/xattention_gpu_test.cpp | 558 +++++------------- 3 files changed, 160 insertions(+), 429 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp index f0e80007d7f60c..d3c3c92f6f5a77 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp @@ -8,9 +8,9 @@ #include #include +#include "paged_attention_inst.h" #include "program_node.h" #include "registry/implementation_manager.hpp" -#include "paged_attention_inst.h" using namespace cldnn; // TODO: Remove once namespaces are aligned diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index f01d8010969005..ec3f1a5f421ba4 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -236,8 +236,8 @@ struct PagedAttentionManager { } auto num_blocks = block_indices.back() + 1; - auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, adjusted_head_size }; - auto key_cache_layout = layout{ key_cache_shape, key_cache_dt, format::bfyx }; + auto key_cache_shape = ov::PartialShape{num_blocks, num_kv_heads, block_size, adjusted_head_size}; + auto key_cache_layout = layout{key_cache_shape, key_cache_dt, format::bfyx}; auto memory = test_engine.allocate_memory(key_cache_layout); for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { @@ -246,33 +246,28 @@ struct PagedAttentionManager { int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size - : block_size; + int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_size * block_idx) : block_size; for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { + for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) { size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + - input_token_offset * num_heads * v_head_size + - head_idx * v_head_size; + ov::float16* data_ptr = key_data[i].data() + input_token_offset * num_kv_heads * v_head_size + head_idx * v_head_size; if (kv_cache_compression) { auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); auto quantized_data_ptr = quantized_data.data(); - // shape: [num_blocks, num_heads, block_size, adjusted_head_size] - size_t output_block_offset = (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + - head_idx * block_size * adjusted_head_size; - size_t output_offset = output_block_offset + - token_idx * v_head_size; + // shape: [num_blocks, num_kv_heads, block_size, adjusted_head_size] + size_t output_block_offset = + (start_block_idx + block_idx) * num_kv_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + token_idx * v_head_size; set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); } else { - // shape: [num_blocks, num_heads, block_size, v_head_size] - size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + - head_idx * block_size * v_head_size + - token_idx * v_head_size; + // shape: [num_blocks, num_kv_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_kv_heads * block_size * v_head_size + + head_idx * block_size * v_head_size + token_idx * v_head_size; set_values(test_stream, memory, data_ptr, v_head_size, output_offset); } diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp index 1a189dc6447798..2eae9c3908cd20 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp @@ -10,12 +10,10 @@ #include #include #include +#include #include +#include -#include "openvino/reference/divide.hpp" -#include "openvino/reference/matmul.hpp" -#include "openvino/reference/softmax.hpp" -#include "openvino/reference/transpose.hpp" #include "openvino/runtime/tensor.hpp" #include "primitive_inst.h" #include "random_generator.hpp" @@ -42,6 +40,7 @@ struct XAttentionCacheRotationDescriptor { struct XAttentionManager { int num_heads; + int num_kv_heads; int k_head_size; int v_head_size; int block_size; @@ -88,6 +87,7 @@ struct XAttentionManager { cldnn::stream& stream, const std::vector& subsequence_descs, int num_heads, + int num_kv_heads, int k_head_size, int v_head_size, int block_size, @@ -98,6 +98,7 @@ struct XAttentionManager { XAttentionCacheRotationDescriptor rotation_config, std::vector threshold) : num_heads(num_heads), + num_kv_heads(num_kv_heads), k_head_size(k_head_size), v_head_size(v_head_size), block_size(block_size), @@ -123,8 +124,8 @@ struct XAttentionManager { max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); query_data.push_back(generate_realistic_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); - key_data.push_back(generate_realistic_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); - value_data.push_back(generate_realistic_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); + key_data.push_back(generate_realistic_data(rg, num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); + value_data.push_back(generate_realistic_data(rg, num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); past_lens.push_back(subsequence_desc.past_len); int subsequence_start_pos = subsequence_begins[i]; @@ -176,15 +177,15 @@ struct XAttentionManager { } memory::ptr get_query_memory() { - return get_QKV_memory(query_data, k_head_size, false); + return get_QKV_memory(query_data, num_heads, k_head_size, false); } memory::ptr get_key_memory() { - return get_QKV_memory(key_data, k_head_size, true); + return get_QKV_memory(key_data, num_kv_heads, k_head_size, true); } memory::ptr get_value_memory() { - return get_QKV_memory(value_data, v_head_size, true); + return get_QKV_memory(value_data, num_kv_heads, v_head_size, true); } memory::ptr get_key_cache_memory() { @@ -196,7 +197,7 @@ struct XAttentionManager { } auto num_blocks = block_indices.back() + 1; - auto key_cache_shape = ov::PartialShape{num_blocks, num_heads, block_size, adjusted_head_size}; + auto key_cache_shape = ov::PartialShape{num_blocks, num_kv_heads, block_size, adjusted_head_size}; auto key_cache_layout = layout{key_cache_shape, key_cache_dt, format::bfyx}; auto memory = test_engine.allocate_memory(key_cache_layout); @@ -206,18 +207,18 @@ struct XAttentionManager { int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size : block_size; + int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_size * block_idx) : block_size; for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { + for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) { size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + input_token_offset * num_heads * v_head_size + head_idx * v_head_size; + ov::float16* data_ptr = key_data[i].data() + input_token_offset * num_kv_heads * v_head_size + head_idx * v_head_size; if (kv_cache_compression) { auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); auto quantized_data_ptr = quantized_data.data(); - // shape: [num_blocks, num_heads, block_size, adjusted_head_size] + // shape: [num_blocks, num_kv_heads, block_size, adjusted_head_size] size_t output_block_offset = - (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; + (start_block_idx + block_idx) * num_kv_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; size_t output_offset = output_block_offset + token_idx * v_head_size; set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); @@ -225,8 +226,8 @@ struct XAttentionManager { set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); } else { - // shape: [num_blocks, num_heads, block_size, v_head_size] - size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + + // shape: [num_blocks, num_kv_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_kv_heads * block_size * v_head_size + head_idx * block_size * v_head_size + token_idx * v_head_size; set_values(test_stream, memory, data_ptr, v_head_size, output_offset); @@ -249,7 +250,7 @@ struct XAttentionManager { } auto num_blocks = block_indices.back() + 1; - auto value_cache_shape = ov::PartialShape{num_blocks, num_heads, block_size, adjusted_head_size}; + auto value_cache_shape = ov::PartialShape{num_blocks, num_kv_heads, block_size, adjusted_head_size}; auto value_cache_layout = layout{value_cache_shape, value_cache_dt, format::bfyx}; auto memory = test_engine.allocate_memory(value_cache_layout); @@ -259,18 +260,18 @@ struct XAttentionManager { int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size : block_size; + int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_size * block_idx) : block_size; for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { + for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) { size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = value_data[i].data() + input_token_offset * num_heads * v_head_size + head_idx * v_head_size; + ov::float16* data_ptr = value_data[i].data() + input_token_offset * num_kv_heads * v_head_size + head_idx * v_head_size; if (kv_cache_compression) { auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); auto quantized_data_ptr = quantized_data.data(); - // shape: [num_blocks, num_heads, block_size, adjusted_head_size] + // shape: [num_blocks, num_kv_heads, block_size, adjusted_head_size] size_t output_block_offset = - (start_block_idx + block_idx) * num_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; + (start_block_idx + block_idx) * num_kv_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; size_t output_offset = output_block_offset + token_idx * v_head_size; set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); @@ -278,8 +279,8 @@ struct XAttentionManager { set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); } else { - // shape: [num_blocks, num_heads, block_size, v_head_size] - size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * v_head_size + + // shape: [num_blocks, num_kv_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_kv_heads * block_size * v_head_size + head_idx * block_size * v_head_size + token_idx * v_head_size; set_values(test_stream, memory, data_ptr, v_head_size, output_offset); @@ -417,12 +418,12 @@ struct XAttentionManager { return memory; } - memory::ptr get_QKV_memory(std::vector>& input_data, int k_head_size, bool skip_past_len) { + memory::ptr get_QKV_memory(std::vector>& input_data, int num_heads, int head_size, bool skip_past_len) { int total_tokens = 0; for (const auto& subsequence_desc : subsequence_descs) total_tokens += subsequence_desc.num_tokens; - auto query_shape = ov::PartialShape{total_tokens, num_heads * k_head_size}; + auto query_shape = ov::PartialShape{total_tokens, num_heads * head_size}; auto query_layout = layout{query_shape, data_types::f16, format::bfyx}; auto memory = test_engine.allocate_memory(query_layout); @@ -434,12 +435,12 @@ struct XAttentionManager { if (skip_past_len) input_token_offset += subsequence_descs[subsequence_idx].past_len; - ov::float16* data_ptr = input_data[subsequence_idx].data() + input_token_offset * num_heads * k_head_size + head_idx * k_head_size; + ov::float16* data_ptr = input_data[subsequence_idx].data() + input_token_offset * num_heads * head_size + head_idx * head_size; size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx; - size_t output_offset = output_token_offset * num_heads * k_head_size + head_idx * k_head_size; + size_t output_offset = output_token_offset * num_heads * head_size + head_idx * head_size; - set_values(test_stream, memory, data_ptr, k_head_size, output_offset); + set_values(test_stream, memory, data_ptr, head_size, output_offset); } } } @@ -544,335 +545,6 @@ struct XAttentionManager { } }; -using Shape = std::vector; - -using CMXAttentionBlockIndex = std::pair; // .first is the *query* dimension block index, .second is *key* -using CMXAttentionRetainedBlockIndices = std::set; -using CMXAttentionRetainedBlockIndicesForAllHeads = std::vector; - -template -class CMXAttentionBlockSelector { -public: - CMXAttentionBlockSelector(double threshold, size_t block_size, size_t stride) : m_threshold(threshold), m_block_size(block_size), m_stride(stride) { - OPENVINO_ASSERT(m_block_size % m_stride == 0); - } - - void diagonal_reshape(const T* input_data, const Shape& input_shape, T* output_data, const Shape& output_shape, bool is_antidiagonal) { - OPENVINO_ASSERT(input_shape.size() == 3); - OPENVINO_ASSERT(output_shape.size() == 3); - size_t H = input_shape[0]; - size_t Q_orig = input_shape[1]; - size_t D = input_shape[2]; - size_t Q_new = output_shape[1]; - - OPENVINO_ASSERT(Q_orig % m_stride == 0); - OPENVINO_ASSERT(Q_orig / m_stride == Q_new); - - for (size_t h = 0; h < H; ++h) { - size_t head_in_offset = h * Q_orig * D; - size_t head_out_offset = h * Q_new * m_stride * D; - - for (size_t s = 0; s < m_stride; ++s) { - for (size_t q = 0; q < Q_new; ++q) { - size_t in_idx; - if (is_antidiagonal) { - // Anti-diagonal: (stride - 1 - s + q * stride) - in_idx = head_in_offset + (m_stride - 1 - s + q * m_stride) * D; - } else { - // Normal diagonal: (s + q * stride) - in_idx = head_in_offset + (s + q * m_stride) * D; - } - - size_t out_idx = head_out_offset + q * m_stride * D + s * D; - std::memcpy(output_data + out_idx, input_data + in_idx, D * sizeof(T)); - } - } - } - } - - void transpose_matmul_scale(const T* reshaped_query_data, - const T* reshaped_key_data, - const Shape& reshaped_query_shape, - const Shape& reshaped_key_shape, - T* out, - const Shape& out_shape) { - OPENVINO_ASSERT(reshaped_key_shape.size() == 3); - OPENVINO_ASSERT(reshaped_query_shape.size() == 3); - OPENVINO_ASSERT(reshaped_query_shape[0] == reshaped_key_shape[0]); - OPENVINO_ASSERT(reshaped_query_shape[2] == reshaped_key_shape[2]); - - OPENVINO_ASSERT(out_shape.size() == 3); - OPENVINO_ASSERT(out_shape[0] == reshaped_query_shape[0]); - OPENVINO_ASSERT(out_shape[1] == reshaped_query_shape[1]); - OPENVINO_ASSERT(out_shape[2] == reshaped_key_shape[1]); - - ov::reference::matmul(reshaped_query_data, reshaped_key_data, out, reshaped_query_shape, reshaped_key_shape, out_shape, false, true); - - size_t out_size = out_shape[0] * out_shape[1] * out_shape[2]; - - for (size_t i = 0; i < out_size; i++) { - out[i] = out[i] / std::sqrt(reshaped_query_shape[2] * m_stride); - } - } - - void softmax(const T* reshaped_qk_product_data, const Shape& reshaped_qk_product_shape, T* out, const Shape& out_shape) { - OPENVINO_ASSERT(reshaped_qk_product_shape.size() == 3); - OPENVINO_ASSERT(reshaped_qk_product_shape == out_shape); - ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2}); - } - - void block_sum_attention_scores(const T* attention_scores_data, const Shape& attention_scores_shape, T* out, const Shape& out_shape) { - OPENVINO_ASSERT(attention_scores_shape.size() == 3); - size_t antidiagonals_per_xattention_block = m_block_size / m_stride; - OPENVINO_ASSERT(attention_scores_shape[1] % antidiagonals_per_xattention_block == 0); - OPENVINO_ASSERT(attention_scores_shape[2] % antidiagonals_per_xattention_block == 0); - - OPENVINO_ASSERT(out_shape[0] == attention_scores_shape[0]); - OPENVINO_ASSERT(out_shape[1] == attention_scores_shape[1] / antidiagonals_per_xattention_block); - OPENVINO_ASSERT(out_shape[2] == attention_scores_shape[2] / antidiagonals_per_xattention_block); - - std::memset(out, 0, out_shape[0] * out_shape[1] * out_shape[2] * sizeof(T)); - - for (size_t head_idx = 0; head_idx < attention_scores_shape[0]; head_idx++) { - size_t in_head_offset = head_idx * attention_scores_shape[1] * attention_scores_shape[2]; - size_t out_head_offset = head_idx * out_shape[1] * out_shape[2]; - for (size_t query_len_idx = 0; query_len_idx < attention_scores_shape[1]; query_len_idx++) { - for (size_t key_len_idx = 0; key_len_idx < attention_scores_shape[2]; key_len_idx++) { - size_t query_block_idx = query_len_idx / antidiagonals_per_xattention_block; - size_t key_block_idx = key_len_idx / antidiagonals_per_xattention_block; - auto target_block_sum_ptr = out + out_head_offset + query_block_idx * out_shape[2] + key_block_idx; - *target_block_sum_ptr += *(attention_scores_data + in_head_offset + query_len_idx * attention_scores_shape[2] + key_len_idx); - } - } - } - } - - CMXAttentionRetainedBlockIndicesForAllHeads get_block_indices_to_keep(T* blocked_attention_scores_data, const Shape& blocked_attention_scores_shape) { - OPENVINO_ASSERT(blocked_attention_scores_shape.size() == 3, "Expected shape [num_heads, q_block_num, k_block_num]"); - - size_t num_heads = blocked_attention_scores_shape[0]; - size_t q_block_num = blocked_attention_scores_shape[1]; - size_t k_block_num = blocked_attention_scores_shape[2]; - - CMXAttentionRetainedBlockIndicesForAllHeads retval(num_heads); - - std::vector>> mask(num_heads, std::vector>(q_block_num, std::vector(k_block_num, false))); - - for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { - for (size_t q_block_idx = 0; q_block_idx < q_block_num; q_block_idx++) { - size_t diagonal_k = q_block_idx; - if (diagonal_k < k_block_num) { - mask[head_idx][q_block_idx][diagonal_k] = true; - } - // Step1: First column reserved - mask[head_idx][q_block_idx][0] = true; - - // Step2: Create other_values(masked_fill) - std::vector> other_values; - for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { - if (mask[head_idx][q_block_idx][k_block_idx]) - continue; - size_t offset = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + k_block_idx; - other_values.emplace_back(static_cast(blocked_attention_scores_data[offset]), k_block_idx); - } - - // Step3: Sort other-values in descending order - std::sort(other_values.begin(), other_values.end(), [](const auto& a, const auto& b) { - return a.first > b.first; - }); - - // Step4: Create cumulative_sum_without_self,cat([0, diagonal_sum, sorted_values[:-1]]) - std::vector sorted_scores; - sorted_scores.push_back(0.0); - // diagonal + First column score - size_t offset_diag = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + diagonal_k; - float diag_score = static_cast(blocked_attention_scores_data[offset_diag]); - float first_col_score = 0.0; - if (diagonal_k != 0) { - size_t offset_first = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num + 0; - first_col_score = static_cast(blocked_attention_scores_data[offset_first]); - } - sorted_scores.push_back(diag_score + first_col_score); - - for (auto& p : other_values) { - sorted_scores.push_back(p.first); - } - if (q_block_idx == 0) { - sorted_scores.pop_back(); - } - - // Step5: Calculate cumsum_without_self: cumsum of right-shifted sorted_scores - std::vector cumsum_without_self(sorted_scores.size(), 0.0); - float running = 0.0; - for (size_t i = 0; i < sorted_scores.size(); ++i) { - cumsum_without_self[i] = running; - running += sorted_scores[i]; - } - - // Step6: Generate required_sum - size_t offset_row_start = head_idx * q_block_num * k_block_num + q_block_idx * k_block_num; - float row_sum = 0.0; - for (size_t k = 0; k < k_block_num; k++) { - row_sum += static_cast(blocked_attention_scores_data[offset_row_start + k]); - } - float required_sum = row_sum * m_threshold; - - // Step7: Create index_mask - std::vector index_mask(cumsum_without_self.size(), false); - for (size_t i = 0; i < cumsum_without_self.size(); i++) { - index_mask[i] = (cumsum_without_self[i] < required_sum); - } - - // Step8: Create index - std::vector index(index_mask.size(), 0); - for (size_t i = 0; i < index_mask.size(); i++) { - if (index_mask[i]) { - if (i == 0) - index[i] = 0; - else if (i == 1) - index[i] = diagonal_k; - else if (i - 2 < other_values.size()) - index[i] = other_values[i - 2].second; - else - index[i] = 0; - } - } - - // Step9: Get retval - for (size_t i = 0; i < index.size(); i++) { - size_t k_block_idx = index[i]; - if (index_mask[i] && k_block_idx < k_block_num) { - mask[head_idx][q_block_idx][k_block_idx] = true; - } - } - - for (size_t k_block_idx = 0; k_block_idx < k_block_num; k_block_idx++) { - if (mask[head_idx][q_block_idx][k_block_idx]) - retval[head_idx].insert({q_block_idx, k_block_idx}); - } - } - } - - return retval; - } - - CMXAttentionRetainedBlockIndicesForAllHeads select_blocks(const T* query_data, - const Shape& query_shape, - const T* key_data, - const Shape& key_shape, - int chunk_size = -1) { - OPENVINO_ASSERT(query_shape.size() == 3 && key_shape.size() == 3); - OPENVINO_ASSERT(query_shape[0] == key_shape[0] && query_shape[2] == key_shape[2]); - OPENVINO_ASSERT(query_shape[1] % m_stride == 0 && key_shape[1] % m_stride == 0); - OPENVINO_ASSERT(query_shape[1] % m_block_size == 0 && key_shape[1] % m_block_size == 0); - - const size_t num_heads = query_shape[0]; - const size_t q_len = query_shape[1]; - const size_t k_len = key_shape[1]; - const size_t head_dim = query_shape[2]; - if (chunk_size == -1) - chunk_size = static_cast(q_len); - - auto pad_seq = [&](const T* src_data, size_t seq_len) { - size_t num_to_pad = ((seq_len + chunk_size - 1) / chunk_size) * chunk_size - seq_len; - Shape pad_shape = {num_heads, seq_len + num_to_pad, head_dim}; - auto buf = allocate_buf(pad_shape); - - for (size_t h = 0; h < num_heads; ++h) { - size_t src_off = h * seq_len * head_dim; - size_t dst_off = h * (seq_len + num_to_pad) * head_dim; - std::memcpy(buf.get() + dst_off, src_data + src_off, seq_len * head_dim * sizeof(T)); - if (num_to_pad) - std::fill(buf.get() + dst_off + seq_len * head_dim, buf.get() + dst_off + (seq_len + num_to_pad) * head_dim, T(0)); - } - return std::make_pair(std::move(buf), pad_shape); - }; - - // ======== Pad Query & Key ======== - auto [pad_query_buf, pad_query_shape] = pad_seq(query_data, q_len); - auto [pad_key_buf, pad_key_shape] = pad_seq(key_data, k_len); - - // ======== Diagonal Reshape ======== - const size_t reshaped_q_len = pad_query_shape[1] / m_stride; - const size_t reshaped_k_len = pad_key_shape[1] / m_stride; - Shape q_shape_r = {num_heads, reshaped_q_len, head_dim * m_stride}; - Shape k_shape_r = {num_heads, reshaped_k_len, head_dim * m_stride}; - - auto q_buf = allocate_buf(q_shape_r); - auto k_buf = allocate_buf(k_shape_r); - diagonal_reshape(pad_query_buf.get(), pad_query_shape, q_buf.get(), q_shape_r, true); - diagonal_reshape(pad_key_buf.get(), pad_key_shape, k_buf.get(), k_shape_r, false); - pad_query_buf.reset(); - pad_key_buf.reset(); - - // ======== QK^T + scale ======== - Shape qk_shape = {num_heads, reshaped_q_len, reshaped_k_len}; - auto qk_buf = allocate_buf(qk_shape); - transpose_matmul_scale(q_buf.get(), k_buf.get(), q_shape_r, k_shape_r, qk_buf.get(), qk_shape); - q_buf.reset(); - k_buf.reset(); - - // ======== Causal Mask ======== - auto causal_mask_buf = allocate_buf(qk_shape); - std::fill(causal_mask_buf.get(), causal_mask_buf.get() + ov::shape_size(qk_shape), T(0)); - const size_t reshaped_chunk_size = q_len / m_stride; - const size_t k_chunk_num = (k_len + ((k_len + chunk_size - 1) / chunk_size * chunk_size - k_len)) / q_len; - const size_t k_reshaped_seq_len = pad_key_shape[1] / m_stride; - const size_t k_reshaped_num_to_pad = pad_key_shape[1] / m_stride - k_len / m_stride; - const size_t chunk_start = (k_chunk_num - 1) * reshaped_chunk_size; - const size_t chunk_end = chunk_start + reshaped_chunk_size; - const T neg_inf = std::numeric_limits::lowest(); - - for (size_t h = 0; h < num_heads; ++h) { - for (size_t q = 0; q < reshaped_chunk_size; ++q) { - size_t base = h * reshaped_chunk_size * (reshaped_chunk_size * k_chunk_num) + q * (reshaped_chunk_size * k_chunk_num); - - for (size_t k = k_reshaped_seq_len - k_reshaped_num_to_pad; k < k_reshaped_seq_len; ++k) - causal_mask_buf.get()[base + k] = neg_inf; - for (size_t k = q + 1; k < reshaped_chunk_size; ++k) - causal_mask_buf.get()[base + chunk_start + k] = neg_inf; - for (size_t k = chunk_end; k < reshaped_chunk_size * k_chunk_num; ++k) - causal_mask_buf.get()[base + k] = neg_inf; - } - } - // ======== qk += mask ======== - for (size_t i = 0; i < ov::shape_size(qk_shape); ++i) - qk_buf.get()[i] += causal_mask_buf.get()[i]; - causal_mask_buf.reset(); - - // ======== softmax ======== - auto attn_score_buf = allocate_buf(qk_shape); - softmax(qk_buf.get(), qk_shape, attn_score_buf.get(), qk_shape); - qk_buf.reset(); - - // ======== block sum + select ======== - const size_t blocks_per_axis = m_block_size / m_stride; - Shape block_sum_shape = {num_heads, reshaped_q_len / blocks_per_axis, reshaped_k_len / blocks_per_axis}; - auto block_sum_buf = allocate_buf(block_sum_shape); - block_sum_attention_scores(attn_score_buf.get(), qk_shape, block_sum_buf.get(), block_sum_shape); - attn_score_buf.reset(); - - auto selected_block_indices = get_block_indices_to_keep(block_sum_buf.get(), block_sum_shape); - block_sum_buf.reset(); - - return selected_block_indices; - } - - std::shared_ptr allocate_buf(const Shape& shape) { - return std::shared_ptr(new T[ov::shape_size(shape)]); - } - - size_t pad_to_block(size_t token_length) { - return (token_length + m_block_size - 1) / m_block_size * m_block_size; - } - - double m_threshold; - - size_t m_block_size; - - size_t m_stride; -}; - struct xAttentionReference { xAttentionReference(XAttentionManager& xam) : xam(xam), test_engine(xam.test_engine), test_stream(xam.test_stream) {} @@ -902,7 +574,7 @@ struct xAttentionReference { xam.rotation_trig_lut, index, subsequence_rotated_block_idx, - xam.num_heads, + xam.num_kv_heads, xam.k_head_size, xam.block_size, xam.rotation_config.per_block); @@ -918,6 +590,7 @@ struct xAttentionReference { subsequence_desc.num_tokens, kv_seq_len, xam.num_heads, + xam.num_kv_heads, xam.k_head_size, xam.v_head_size, window_size, @@ -940,6 +613,7 @@ struct xAttentionReference { int num_queries, int num_keys, int num_heads, + int num_kv_heads, int k_head_size, int v_head_size, int window_size, @@ -948,9 +622,18 @@ struct xAttentionReference { double threshold = 0.9, size_t block_size = 128, size_t stride = 16) { - auto query_layout = layout{{1, num_queries, num_heads, k_head_size}, data_types::f16, format::bfyx}; - auto key_layout = layout{{1, num_keys, num_heads, k_head_size}, data_types::f16, format::bfyx}; - auto value_layout = layout{{1, num_keys, num_heads, v_head_size}, data_types::f16, format::bfyx}; + auto query_shape = ov::PartialShape{1, num_queries, num_heads, k_head_size}; + auto key_shape = ov::PartialShape{1, num_keys, num_kv_heads, k_head_size}; + auto value_shape = ov::PartialShape{1, num_keys, num_kv_heads, v_head_size}; + if (num_heads != num_kv_heads) { + query_shape = ov::PartialShape{num_queries, num_kv_heads, (num_heads / num_kv_heads), k_head_size}; + key_shape = ov::PartialShape{num_keys, num_kv_heads, 1, k_head_size}; + value_shape = ov::PartialShape{num_keys, num_kv_heads, 1, v_head_size}; + } + auto query_layout = layout{query_shape, data_types::f16, format::bfyx}; + auto key_layout = layout{key_shape, data_types::f16, format::bfyx}; + auto value_layout = layout{value_shape, data_types::f16, format::bfyx}; + auto scale_layout = cldnn::layout({1}, data_types::f16, format::bfyx); OPENVINO_ASSERT(query_layout.count() == query_data.size()); OPENVINO_ASSERT(key_layout.count() == key_data.size()); @@ -959,10 +642,12 @@ struct xAttentionReference { auto query_mem = test_engine.allocate_memory(query_layout); auto key_mem = test_engine.allocate_memory(key_layout); auto value_mem = test_engine.allocate_memory(value_layout); + auto scale_mem = test_engine.allocate_memory(scale_layout); set_values(query_mem, query_data); set_values(key_mem, key_data); set_values(value_mem, value_data); + set_values(scale_mem, {static_cast(scale)}); auto reorder_qhk_to_hqd = [&](const std::vector& src, int outer_len, int num_heads, int head_dim) { std::vector dst(num_heads * outer_len * head_dim); @@ -979,7 +664,7 @@ struct xAttentionReference { const auto query_data_3d = reorder_qhk_to_hqd(query_data, num_queries, num_heads, k_head_size); const auto key_data_3d = reorder_qhk_to_hqd(key_data, num_keys, num_heads, k_head_size); - CMXAttentionRetainedBlockIndicesForAllHeads retained_blocks; + ov::reference::XAttentionRetainedBlockIndicesForAllHeads retained_blocks; if (num_queries >= static_cast(block_size)) { const size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; const size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; @@ -1000,29 +685,56 @@ struct xAttentionReference { return static_cast(v); }); } - CMXAttentionBlockSelector selector(threshold, block_size, stride); + ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); retained_blocks = selector.select_blocks(query_padded.data(), {static_cast(num_heads), padded_q, static_cast(k_head_size)}, key_padded.data(), {static_cast(num_heads), padded_k, static_cast(k_head_size)}); } - auto mask_mem = get_mask_mem_combined_multi_head(num_queries, num_keys, num_heads, sliding_window_size, retained_blocks, static_cast(block_size)); - + auto mask_mem = get_mask_mem_combined_multi_head(num_queries, + num_keys, + num_heads, + num_kv_heads, + sliding_window_size, + retained_blocks, + static_cast(block_size)); topology topology; - topology.add(input_layout("query", query_layout), - input_layout("key", key_layout), - input_layout("value", value_layout), - data("mask", mask_mem), - permute("query_transposed", input_info("query"), {0, 2, 1, 3}), - permute("key_transposed", input_info("key"), {0, 2, 1, 3}), - permute("value_transposed", input_info("value"), {0, 2, 1, 3}), - gemm("qk_gemm", {input_info("query_transposed"), input_info("key_transposed")}, data_types::f16, false, true, scale), - eltwise("eltwise", {input_info("qk_gemm"), input_info("mask")}, eltwise_mode::sum), - softmax("softmax", input_info("eltwise"), -1), - gemm("qkv_gemm", {input_info("softmax"), input_info("value_transposed")}, data_types::f16, false, false), - permute("qkv_gemm_transposed", input_info("qkv_gemm"), {0, 2, 1, 3}), - reorder("output_data", input_info("qkv_gemm_transposed"), format::bfyx, data_types::f16), - reorder("scores_data", input_info("softmax"), format::bfyx, data_types::f16)); + if (num_heads == num_kv_heads) { + topology.add(input_layout("query", query_layout), + input_layout("key", key_layout), + input_layout("value", value_layout), + data("mask", mask_mem), + data("scale", scale_mem), + permute("query_transposed", input_info("query"), {0, 2, 1, 3}), + permute("key_transposed", input_info("key"), {0, 2, 3, 1}), + permute("value_transposed", input_info("value"), {0, 2, 1, 3}), + gemm("qk_gemm", {input_info("query_transposed"), input_info("key_transposed")}, data_types::f16, false, false), + eltwise("scale_div", {input_info("qk_gemm"), input_info("scale")}, eltwise_mode::prod), + eltwise("eltwise", {input_info("scale_div"), input_info("mask")}, eltwise_mode::sum), + softmax("softmax", input_info("eltwise"), -1), + gemm("qkv_gemm", {input_info("softmax"), input_info("value_transposed")}, data_types::f16, false, false), + permute("qkv_gemm_transposed", input_info("qkv_gemm"), {0, 2, 1, 3}), + reorder("output_data", input_info("qkv_gemm_transposed"), format::bfyx, data_types::f16), + reorder("scores_data", input_info("softmax"), format::bfyx, data_types::f16)); + } else { + topology.add(input_layout("query", query_layout), + input_layout("key", key_layout), + input_layout("value", value_layout), + data("mask", mask_mem), + data("scale", scale_mem), + permute("query_transposed", input_info("query"), {1, 2, 0, 3}), + permute("key_transposed", input_info("key"), {1, 2, 3, 0}), + permute("value_transposed", input_info("value"), {1, 2, 0, 3}), + gemm("qk_gemm", {input_info("query_transposed"), input_info("key_transposed")}, data_types::f16, false, false), + eltwise("scale_div", {input_info("qk_gemm"), input_info("scale")}, eltwise_mode::prod), + eltwise("eltwise", {input_info("scale_div"), input_info("mask")}, eltwise_mode::sum), + softmax("softmax", input_info("eltwise"), -1), + gemm("qkv_gemm", {input_info("softmax"), input_info("value_transposed")}, data_types::f16, false, false), + reshape("qkv_gemm_reshape", input_info("qkv_gemm"), {1, num_heads, v_head_size, num_queries}), + permute("qkv_gemm_transposed", input_info("qkv_gemm_reshape"), {0, 2, 1, 3}), + reorder("output_data", input_info("qkv_gemm_transposed"), format::bfyx, data_types::f16), + reorder("scores_data", input_info("softmax"), format::bfyx, data_types::f16)); + } ExecutionConfig config = get_test_default_config(test_engine); config.set_property(ov::intel_gpu::optimize_data(true)); @@ -1073,16 +785,37 @@ struct xAttentionReference { memory::ptr get_mask_mem_combined_multi_head(int num_queries, int num_keys, int num_heads, + int num_kv_heads, int sliding_window_size, - const CMXAttentionRetainedBlockIndicesForAllHeads& retained_blocks, + const ov::reference::XAttentionRetainedBlockIndicesForAllHeads& retained_blocks, int block_size) { - auto mask_shape = ov::PartialShape{1, num_heads, num_queries, num_keys}; + OPENVINO_ASSERT(num_kv_heads > 0, "num_kv_heads must be > 0"); + OPENVINO_ASSERT(num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"); + + int heads_per_kv = num_heads / num_kv_heads; + + ov::PartialShape mask_shape; + if (num_heads == num_kv_heads) { + mask_shape = ov::PartialShape{1, num_heads, num_queries, num_keys}; + } else { + mask_shape = ov::PartialShape{num_kv_heads, heads_per_kv, num_queries, num_keys}; + } + auto mask_layout = layout{mask_shape, data_types::f16, format::bfyx}; auto mask_mem = test_engine.allocate_memory(mask_layout); - mem_lock mem_ptr(mask_mem, test_stream); - for (int h = 0; h < num_heads; h++) { + size_t total_elems = mask_layout.count(); + for (size_t i = 0; i < total_elems; ++i) + mem_ptr[i] = std::numeric_limits::lowest(); + + for (int h = 0; h < num_heads; ++h) { + int kv_idx = (num_heads == num_kv_heads) ? 0 : (h / heads_per_kv); + int head_in_kv = (num_heads == num_kv_heads) ? h : (h % heads_per_kv); + + size_t head_offset = (static_cast(kv_idx) * heads_per_kv + static_cast(head_in_kv)) * static_cast(num_queries) * + static_cast(num_keys); + if (retained_blocks.empty() || retained_blocks[h].empty()) { for (int i = 0; i < num_queries; i++) { for (int j = 0; j < num_keys; j++) { @@ -1103,18 +836,12 @@ struct xAttentionReference { if (is_min) value = std::numeric_limits::lowest(); } - mem_ptr[h * num_queries * num_keys + i * num_keys + j] = value; + mem_ptr[head_offset + i * num_keys + j] = value; } } continue; } - for (int i = 0; i < num_queries; i++) { - for (int j = 0; j < num_keys; j++) { - mem_ptr[h * num_queries * num_keys + i * num_keys + j] = std::numeric_limits::lowest(); - } - } - for (int i = 0; i < num_queries; i++) { int left_idx = 0; int right_idx = 0; @@ -1149,7 +876,7 @@ struct xAttentionReference { for (int j = k_start; j < k_end; j++) { if (j >= left_idx && j <= right_idx) { - mem_ptr[h * num_queries * num_keys + i * num_keys + j] = ov::float16(0.f); + mem_ptr[head_offset + i * num_keys + j] = ov::float16(0.f); } } } @@ -1213,6 +940,7 @@ struct xAttentionTest : public ::testing::TestWithParam { get_test_stream(), p.subsequences, p.num_heads, + p.num_kv_heads, p.k_head_size, p.v_head_size, p.block_size, @@ -1280,10 +1008,18 @@ struct xAttentionTest : public ::testing::TestWithParam { // make layouts dynamic query_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads * p.k_head_size}); - key_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads * p.k_head_size}); - value_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads * p.v_head_size}); - key_cache_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads, p.block_size, p.k_head_size}); - value_cache_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads, p.block_size, p.v_head_size}); + key_layout.set_partial_shape(ov::PartialShape{-1, p.num_kv_heads * p.k_head_size}); + value_layout.set_partial_shape(ov::PartialShape{-1, p.num_kv_heads * p.v_head_size}); + { + auto pshape = key_cache_layout.get_partial_shape(); + pshape[0] = -1; + key_cache_layout.set_partial_shape(pshape); + } + { + auto pshape = value_cache_layout.get_partial_shape(); + pshape[0] = -1; + value_cache_layout.set_partial_shape(pshape); + } past_lens_layout.set_partial_shape(ov::PartialShape{-1}); subsequence_begins_layout.set_partial_shape(ov::PartialShape{-1}); block_indices_layout.set_partial_shape(ov::PartialShape{-1}); @@ -1352,7 +1088,7 @@ struct xAttentionTest : public ::testing::TestWithParam { pa_prim.k_head_size = p.k_head_size; pa_prim.v_head_size = p.v_head_size; - pa_prim.kv_heads_num = p.num_heads; + pa_prim.kv_heads_num = p.num_kv_heads; pa_prim.heads_num = p.num_heads; pa_prim.scale_val = xam.get_default_scale(); pa_prim.has_alibi = false; @@ -1362,7 +1098,6 @@ struct xAttentionTest : public ::testing::TestWithParam { pa_prim.sliding_window = p.sliding_window_size; pa_prim.is_key_by_channel = (p.key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL); pa_prim.has_xattention = true; - topology topology; topology.add(input_layout("query", query_layout), @@ -1482,6 +1217,7 @@ struct xAttentionTest : public ::testing::TestWithParam { struct xattention_test_params { std::vector subsequences; int num_heads; + int num_kv_heads; int k_head_size; int v_head_size; int block_size; @@ -1522,16 +1258,16 @@ INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, ::testing::ValuesIn(std::vector{ /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, token_size>=32, disable_mix_mode */ - xattention_test_params{ {{32, 0}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - xattention_test_params{ {{1024, 0}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - xattention_test_params{ {{2048, 0}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - - xattention_test_params{ {{1, 31}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 32}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 1023}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 1024}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 127}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 128}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 129}}, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 32}}, 28, 128, 128, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{32, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + xattention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + xattention_test_params{ {{2048, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + + xattention_test_params{ {{1, 31}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 32}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 1023}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 1024}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 127}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 128}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 129}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + xattention_test_params{ {{1, 32}}, 28, 28, 128, 128, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token })); From 5eff824f4fbb28006db22c40e1db4fbe98b5f9b9 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Mon, 20 Oct 2025 00:43:25 +0800 Subject: [PATCH 83/96] Ww/pa cm xattention 1020 (#62) * Fix code style * Clean code --- .../src/graph/impls/cm/paged_attention_gen.cpp | 6 +++--- .../intel_gpu/tests/common/random_generator.hpp | 10 ---------- .../tests/unit/test_cases/paged_attention_gpu_test.cpp | 3 --- 3 files changed, 3 insertions(+), 16 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index ea45c3e61a77e6..ed3840f14621d2 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -123,9 +123,9 @@ float get_xattn_thresh(const kernel_impl_params& params, const size_t seq_idx) { return thresh; } - // Bypass xattn stages in the following conditions - - // either threshold is larger than 1.0, or, q_len is too small - // to compute xattn block_mask. +// Bypass xattn stages in the following conditions - +// either threshold is larger than 1.0, or, q_len is too small +// to compute xattn block_mask. bool bypass_xattn(const kernel_impl_params& params) { auto xattn_thresh = get_xattn_thresh(params); bool bypass = xattn_thresh >= 1.0; diff --git a/src/plugins/intel_gpu/tests/common/random_generator.hpp b/src/plugins/intel_gpu/tests/common/random_generator.hpp index 92f4b32591cb6d..8dfb4a616b1c6d 100644 --- a/src/plugins/intel_gpu/tests/common/random_generator.hpp +++ b/src/plugins/intel_gpu/tests/common/random_generator.hpp @@ -57,16 +57,6 @@ class random_generator { return v; } - template - std::vector generate_random_1d_fixed(size_t a, int start, int step, int k = 100) { - std::vector v(a); - - for (size_t i = 0; i < a; ++i) { - v[i] = static_cast(start + i * step) / k; - } - return v; - } - template std::vector> generate_random_2d(size_t a, size_t b, int min, int max, int k = 8) { std::vector> v(a); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index ec3f1a5f421ba4..c31a8c0fd69056 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -608,9 +608,6 @@ struct PagedAttentionManager { const size_t total_elements_num = tokens_num * num_heads * k_head_size; auto data = rg.generate_random_1d(total_elements_num, -1, 1); - // test code - // auto data = rg.generate_random_1d_fixed(total_elements_num, 0, 1, 10000); - return data; } From 853b56276b1b6bb0da988137caad5b53633c0546 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 17 Oct 2025 15:38:38 +0800 Subject: [PATCH 84/96] PagedAttentionInternBuffIdx --- .../src/graph/impls/cm/paged_attention.cpp | 20 +++++++------- .../graph/impls/cm/paged_attention_gen.cpp | 26 +++++++++---------- .../graph/impls/cm/paged_attention_gen.hpp | 13 ++++++++-- 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 50ec344e598ccd..d489647cd21182 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -173,18 +173,18 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { internal_buffers.emplace_back(16, indexes_dt); // 1: softmax exp_sums // internal buffer for XAttention - auto out_shape = params.output_layouts[0].get_shape(); - const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; - const size_t q_len = out_shape[0]; - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` - const uint32_t N = static_cast(kv_len / STRIDE); - const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); - const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + if (desc->has_xattention) { + auto out_shape = params.output_layouts[0].get_shape(); + const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; + const size_t q_len = out_shape[0]; + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t N = static_cast(kv_len / STRIDE); + const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); + const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); - auto count_kq_max_wg = static_cast(desc->heads_num * N_kq_groups * q_stride_pad); - internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg + auto count_kq_max_wg = static_cast(desc->heads_num * N_kq_groups * q_stride_pad); + internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg - if (desc->has_xattention) { const size_t block_size = get_xattn_block_size(params); OPENVINO_ASSERT(block_size % STRIDE == 0, "sparse block_size must be devidable by stride."); const uint32_t q_block_pad = ceil_div(q_len, block_size); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index ed3840f14621d2..f8bc98cba480a3 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -311,8 +311,8 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); if (desc->has_xattention) { - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // sparse_block_mask - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5}); // sparse_block_mask_wg + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK}); // sparse_block_mask + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK_MERGED}); // sparse_block_mask_wg } args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len @@ -449,8 +449,8 @@ Arguments PagedAttentionGeneratorSingleToken::get_arguments_desc(const kernel_im args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence begins // outputs - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); // partition output - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); // lse output + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_PARTITIONOUT}); // partition output + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_EXPSUMS}); // lse output // scalar args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len==1 @@ -514,9 +514,9 @@ Arguments PagedAttentionGeneratorSingleTokenFinalization::get_arguments_desc(con const auto has_scores_output = params.output_layouts.size() > 1; OPENVINO_ASSERT(!has_scores_output, "[GPU][CM] PagedAttentionGeneratorSingleTokenFinalization with scores output is not supported yet"); - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); // partition data + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_PARTITIONOUT}); // partition data args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // output - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); // lse + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_EXPSUMS}); // lse // scalar args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // kv_partition_num @@ -614,8 +614,8 @@ Arguments XAttentionEstimateGEMMQK::get_arguments_desc(const kernel_impl_params& args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins // outputs - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); // kq_max_wg - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); // kq_exp_partial_sum + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_MAX}); // kq_max_wg + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_EXPSUMS}); // kq_exp_partial_sum // scalar args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // M @@ -735,11 +735,11 @@ Arguments XAttentionEstimateFindBlock::get_arguments_desc(const kernel_impl_para Arguments args; // inputs - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); // kq_max_wg - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); // kq_exp_partial_sum + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_MAX}); // kq_max_wg + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_EXPSUMS}); // kq_exp_partial_sum // outputs - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // block_mask + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK}); // sparse_block_mask // scalar args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len @@ -834,10 +834,10 @@ Arguments XAttentionEstimatePostProc::get_arguments_desc(const kernel_impl_param Arguments args; // inputs - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); // block_mask + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK}); // sparse_block_mask // outputs - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5}); // block_mask_merged + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK_MERGED}); // sparse_block_mask_wg // scalar args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_stride_pad diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 3581fc9d5b03a7..4498c70980a51c 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -52,6 +52,17 @@ struct PagedAttentionRuntimeParams : public ImplRuntimeParams { size_t xattn_k_block_pad; }; +enum PagedAttentionInternBuffIdx { + // for decoding kernels + DECODE_PARTITIONOUT = 0, // 0: intermediate partition output + DECODE_EXPSUMS = 1, // 1: softmax exp_sums + // for xattn kernels + XATTN_GEMMQK_MAX = 2, // 2: kq_max_wg + XATTN_GEMMQK_EXPSUMS = 3, // 3: kq_exp_partial_sum + XATTN_BLOCKMASK = 4, // 4: sparse_block_mask + XATTN_BLOCKMASK_MERGED = 5, // 5: sparse_block_mask_wg +}; + //----------------------------------------------------------------------------------------------------------------- // Helpers of XAttention //----------------------------------------------------------------------------------------------------------------- @@ -124,7 +135,6 @@ class XAttentionEstimateGeneratorBase : public KernelGenerator { class XAttentionEstimateGEMMQK : public XAttentionEstimateGeneratorBase { public: XAttentionEstimateGEMMQK() : XAttentionEstimateGeneratorBase("xattn_gemm_qk") {} - // [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; }; @@ -132,7 +142,6 @@ class XAttentionEstimateGEMMQK : public XAttentionEstimateGeneratorBase { class XAttentionEstimateFindBlock : public XAttentionEstimateGeneratorBase { public: XAttentionEstimateFindBlock() : XAttentionEstimateGeneratorBase("xattn_find_block") {} - // [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; }; From 0870cbb13b52e8d70b6db8df3b9598136d302c16 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 17 Oct 2025 15:55:30 +0800 Subject: [PATCH 85/96] refactor xattention kernel impls by reusing RT parameters, instead of recomputing them. --- .../src/graph/impls/cm/paged_attention.cpp | 108 ++++++----- .../graph/impls/cm/paged_attention_gen.cpp | 179 +++++------------- .../graph/impls/cm/paged_attention_gen.hpp | 17 +- 3 files changed, 113 insertions(+), 191 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index d489647cd21182..fb49c6cc556da1 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -54,6 +54,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { void update_xattn_rt_params(const primitive_inst& instance) { const auto& params = *instance.get_impl_params(); + OPENVINO_ASSERT(!params.is_dynamic()); + const auto desc = params.typed_desc(); auto out_shape = params.output_layouts[0].get_shape(); const size_t block_size = get_xattn_block_size(params); @@ -68,8 +70,24 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { const auto k_block_pad = k_block_in_group * N_kq_groups; auto rt_params = static_cast(m_rt_params.get()); - rt_params->xattn_q_block_pad = q_block_pad; - rt_params->xattn_k_block_pad = k_block_pad; + rt_params->q_block_pad = q_block_pad; + rt_params->k_block_pad = k_block_pad; + + // XAttention estimate is following afer kvcache_update. + const size_t head_size = desc->k_head_size; + + auto querry_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; + + const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` + const uint32_t K = static_cast(STRIDE * head_size); + + const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); + + rt_params->N_kq_groups = N_kq_groups; + rt_params->M = M; + rt_params->N = N; + rt_params->K = K; + rt_params->q_stride_pad = q_stride_pad; } void update_rt_params(const primitive_inst& instance) override { @@ -77,22 +95,30 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { if (m_rt_params == nullptr) { m_rt_params = std::make_unique(); } - GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::update_rt_params()" << std::endl; + const auto& params = *instance.get_impl_params(); auto rt_params = static_cast(m_rt_params.get()); const auto& desc = params.typed_desc(); + rt_params->stage = get_paged_attention_stage(params); const auto max_context_len = get_max_context_len(params); rt_params->max_context_len = max_context_len; - rt_params->partition_size = get_partition_size(desc->has_xattention); - rt_params->num_of_partitions = ceil_div(max_context_len, rt_params->partition_size); - rt_params->stage = get_paged_attention_stage(params); - if (desc->has_xattention) { - update_xattn_rt_params(instance); + GPU_DEBUG_TRACE_DETAIL << "update_rt_params for stage: " << static_cast(rt_params->stage) + << " max_context_len: " << rt_params->max_context_len << std::endl; + + if (rt_params->stage == PagedAttentionStage::GENERATE) { + auto partition_size = get_partition_size(desc->has_xattention); + rt_params->num_of_partitions = ceil_div(max_context_len, partition_size); + + GPU_DEBUG_TRACE_DETAIL << " partition_size: " << partition_size + << " num_of_partitions: " << rt_params->num_of_partitions << std::endl; + } else { + if (desc->has_xattention) { + update_xattn_rt_params(instance); + } } - GPU_DEBUG_TRACE_DETAIL << " max_context_len: " << rt_params->max_context_len << " partition_size: " << rt_params->partition_size - << " num_of_partitions: " << rt_params->num_of_partitions << ", stage: " << static_cast(rt_params->stage) << std::endl; + } // update impl_parameter and rt_parameter @@ -107,7 +133,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { update_stages_flags(instance); auto rt_params = static_cast(m_rt_params.get()); - assert(rt_params != nullptr); + OPENVINO_ASSERT(rt_params != nullptr); GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::execute(): stage = " << static_cast(rt_params->stage) << std::endl; std::vector res_event = events; @@ -140,25 +166,16 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { std::vector internal_buffers; const auto desc = params.typed_desc(); - const auto indexes_dt = ov::element::f32; - auto stage = PagedAttentionStage::UNKNOWN; auto rt_params = static_cast(m_rt_params.get()); + // Assume rt_params are updated, because get_internal_buffer_descs surely occurs after update_rt_params. + OPENVINO_ASSERT(rt_params != nullptr); - size_t partition_size = PA_KV_CACHE_BLOCK_SIZE; - size_t num_of_partitions = 1; - if (rt_params != nullptr && rt_params->num_of_partitions != 0) { - stage = rt_params->stage; - partition_size = rt_params->partition_size; - num_of_partitions = rt_params->num_of_partitions; - } else { - stage = get_paged_attention_stage(params); - const auto max_context_len = get_max_context_len(params); - partition_size = get_partition_size(desc->has_xattention); - num_of_partitions = ceil_div(max_context_len, partition_size); - } - GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::get_internal_buffer_descs(): stage = " << static_cast(stage) - << " partition_size: " << partition_size << " num_of_partitions: " << num_of_partitions << std::endl; + const auto stage = rt_params->stage; + GPU_DEBUG_TRACE_DETAIL << " stage = " << static_cast(stage) << std::endl; if (stage == PagedAttentionStage::GENERATE) { + OPENVINO_ASSERT(rt_params->num_of_partitions != 0); + size_t num_of_partitions = rt_params->num_of_partitions; + const auto& input = params.input_layouts[0]; const int64_t total_tokens = input.get_partial_shape()[0].get_length(); auto buf_elements_count = static_cast(total_tokens * desc->heads_num * num_of_partitions); @@ -167,39 +184,30 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { internal_buffers.emplace_back(tmp_out_elements_count, ov::element::f32); // 0: intermediate partition output internal_buffers.emplace_back(buf_elements_count, ov::element::f32); // 1: softmax exp_sums - GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * 4 << " exp_sums=" << buf_elements_count * 4 << std::endl; + GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * 4 + << " exp_sums=" << buf_elements_count * 4 << std::endl; } else { - internal_buffers.emplace_back(16, indexes_dt); // 0: intermediate partition output - internal_buffers.emplace_back(16, indexes_dt); // 1: softmax exp_sums + internal_buffers.emplace_back(16, ov::element::f32); // 0: intermediate partition output + internal_buffers.emplace_back(16, ov::element::f32); // 1: softmax exp_sums // internal buffer for XAttention if (desc->has_xattention) { - auto out_shape = params.output_layouts[0].get_shape(); - const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; - const size_t q_len = out_shape[0]; - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` - const uint32_t N = static_cast(kv_len / STRIDE); - const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); - const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); - - auto count_kq_max_wg = static_cast(desc->heads_num * N_kq_groups * q_stride_pad); + auto count_kq_max_wg = static_cast(desc->heads_num * rt_params->N_kq_groups * rt_params->q_stride_pad); internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg - const size_t block_size = get_xattn_block_size(params); - OPENVINO_ASSERT(block_size % STRIDE == 0, "sparse block_size must be devidable by stride."); - const uint32_t q_block_pad = ceil_div(q_len, block_size); - const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); - const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); - const uint32_t k_block_pad = k_block_in_group * N_kq_groups; - auto count_kq_exp_partial_sum = static_cast(desc->heads_num * q_stride_pad * k_block_pad); + auto count_kq_exp_partial_sum = static_cast(desc->heads_num * rt_params->q_stride_pad * rt_params->k_block_pad); internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f32); // 3: kq_exp_partial_sum - auto count_elements_mask = static_cast(desc->heads_num * q_block_pad * k_block_pad); + auto count_elements_mask = static_cast(desc->heads_num * rt_params->q_block_pad * rt_params->k_block_pad); internal_buffers.emplace_back(count_elements_mask, ov::element::boolean); // 4: sparse_block_mask - const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); - auto count_elements_mask_merged = static_cast(desc->heads_num * q_block_pad_merged * k_block_pad); + auto count_elements_mask_merged = static_cast(desc->heads_num * rt_params->q_block_pad_merged * rt_params->k_block_pad); internal_buffers.emplace_back(count_elements_mask_merged, ov::element::boolean); // 5: sparse_block_mask_wg + + GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: count_kq_max_wg=" << count_kq_max_wg * 4 + << " count_kq_exp_partial_sum=" << count_kq_exp_partial_sum * 4 + << " count_elements_mask=" << count_elements_mask * 1 + << " count_elements_mask_merged=" << count_kq_exp_partial_sum * 1 << std::endl; } } @@ -212,7 +220,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { }; std::unique_ptr PagedAttentionImplementationManager::create_impl(const program_node& node, const kernel_impl_params& params) const { - assert(node.is_type()); + OPENVINO_ASSERT(node.is_type()); try { return std::make_unique(params); } catch (const std::exception& e) { diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index f8bc98cba480a3..fa0c58f7992986 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -89,13 +89,6 @@ size_t get_partition_size(const bool has_xattention) { } } -size_t get_partition_num(const size_t kv_len, const bool has_xattention) { - const size_t partition_size = get_partition_size(has_xattention); - const size_t partition_num = (kv_len + partition_size - 1) / partition_size; - - return partition_num; -} - // max_context_len = max(past_lens + prompt_lens) size_t get_max_context_len(const kernel_impl_params& params) { const auto& input_mem = params.memory_deps; @@ -228,7 +221,7 @@ Arguments PagedAttentionGeneratorKVCacheUpdate::get_arguments_desc(const kernel_ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() const { return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { - assert(!params.is_dynamic()); + OPENVINO_ASSERT(!params.is_dynamic()); auto& wgs = kd.params.workGroups; const auto desc = params.typed_desc(); @@ -360,7 +353,7 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con auto& scalars = kd.params.scalars; auto desc = params.typed_desc(); auto rtp = static_cast(rt_params); - // assert(rt_params != nullptr); + // OPENVINO_ASSERT(rt_params != nullptr); const size_t heads_num = desc->heads_num; auto query_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; @@ -389,10 +382,10 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con scalars[0].v.s32 = static_cast(q_len); if (num_scalers > 1) { scalars[1].t = ScalarDescriptor::Types::INT32; - scalars[1].v.s32 = static_cast(rtp->xattn_q_block_pad); + scalars[1].v.s32 = static_cast(rtp->q_block_pad); scalars[2].t = ScalarDescriptor::Types::INT32; - scalars[2].v.s32 = static_cast(rtp->xattn_k_block_pad); + scalars[2].v.s32 = static_cast(rtp->k_block_pad); scalars[3].t = ScalarDescriptor::Types::UINT8; const bool validate = !bypass_xattn(params); @@ -460,16 +453,16 @@ Arguments PagedAttentionGeneratorSingleToken::get_arguments_desc(const kernel_im DispatchDataFunc PagedAttentionGeneratorSingleToken::get_dispatch_data_func() const { return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { - assert(!params.is_dynamic()); + OPENVINO_ASSERT(!params.is_dynamic()); auto& wgs = kd.params.workGroups; const auto desc = params.typed_desc(); auto rtp = static_cast(rt_params); - assert(rt_params != nullptr); + OPENVINO_ASSERT(rt_params != nullptr); const size_t batch = params.input_layouts[0].get_partial_shape()[0].get_length(); const size_t heads_num = desc->heads_num; const size_t kv_heads_num = desc->kv_heads_num; - const size_t partition_num = rtp->num_of_partitions; // get_partition_num(rtp->max_context_len); + const size_t partition_num = rtp->num_of_partitions; wgs.global = {batch, kv_heads_num, partition_num}; wgs.local = {1, 1, 1}; @@ -512,7 +505,8 @@ JitConstants PagedAttentionGeneratorSingleTokenFinalization::get_jit_constants(c Arguments PagedAttentionGeneratorSingleTokenFinalization::get_arguments_desc(const kernel_impl_params& params) const { Arguments args; const auto has_scores_output = params.output_layouts.size() > 1; - OPENVINO_ASSERT(!has_scores_output, "[GPU][CM] PagedAttentionGeneratorSingleTokenFinalization with scores output is not supported yet"); + if (has_scores_output) + OPENVINO_THROW("[GPU][CM] PagedAttentionGeneratorSingleTokenFinalization with scores output is not supported yet"); args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_PARTITIONOUT}); // partition data args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // output @@ -526,13 +520,13 @@ Arguments PagedAttentionGeneratorSingleTokenFinalization::get_arguments_desc(con DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_data_func() const { return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { - assert(!params.is_dynamic()); + OPENVINO_ASSERT(!params.is_dynamic()); auto& wgs = kd.params.workGroups; const auto desc = params.typed_desc(); auto rtp = static_cast(rt_params); - assert(rt_params != nullptr); + OPENVINO_ASSERT(rt_params != nullptr); const size_t batch = params.input_layouts[0].get_partial_shape()[0].get_length(); const size_t heads_num = desc->heads_num; @@ -541,7 +535,7 @@ DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_da wgs.local = {1, 1, 1}; auto& scalars = kd.params.scalars; - const size_t partition_num = rtp->num_of_partitions; // get_partition_num(rtp->max_context_len); + const size_t partition_num = rtp->num_of_partitions; std::vector scaler_value = {partition_num}; scalars.resize(scaler_value.size()); @@ -575,6 +569,12 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp int scale_factor_i; std::memcpy(static_cast(&scale_factor_i), &scale_factor, sizeof(scale_factor)); + const uint32_t wg_k = BLOCK_WG_M; + const uint32_t wg_q = BLOCK_WG_N; + const size_t block_size = get_xattn_block_size(params); + OPENVINO_ASSERT(wg_k % block_size == 0, "wg_k should be multiple of block_size then there is no tails from block_size"); + OPENVINO_ASSERT(wg_q % block_size == 0, "wg_q should be multiple of block_size then there is no tails from block_size"); + jit.make("STRIDE", STRIDE); jit.make("HQ", desc->heads_num); jit.make("HK", desc->kv_heads_num); @@ -583,7 +583,7 @@ JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_imp jit.make("SG_N", SG_N); jit.make("BLOCK_SG_M", BLOCK_SG_M); jit.make("BLOCK_SG_N", BLOCK_SG_N); - jit.make("BLOCK_SIZE", get_xattn_block_size(params)); + jit.make("BLOCK_SIZE", block_size); jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE_XATTN); jit.add(make_jit_constant("INV_S", scale_factor_i)); jit.make("BLOCK_SHARE_MAX", BLOCK_WG_N); @@ -631,49 +631,15 @@ Arguments XAttentionEstimateGEMMQK::get_arguments_desc(const kernel_impl_params& DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { - assert(!params.is_dynamic()); + OPENVINO_ASSERT(!params.is_dynamic()); + OPENVINO_ASSERT(rt_params != nullptr); + auto rtp = static_cast(rt_params); const auto desc = params.typed_desc(); - // XAttention estimate is following afer kvcache_update. - const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; - // const size_t kv_heads_num = desc->kv_heads_num; - const size_t heads_num = desc->heads_num; - const size_t head_size = desc->k_head_size; - - auto querry_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; - auto key_layout = params.input_layouts[PagedAttentionInputIdx::KEY]; - - if (DEBUG_ENABLED) { // Debug - std::cout << "XAttentionEstimateGEMMQK::get_dispatch_data_func: " - << "key_layout: " << key_layout.to_string() << ", querry_layout: " << querry_layout.to_string() << std::endl; - std::cout << "\tkey_dims = ["; - for (auto& it : key_layout.get_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tkey_pads = ["; - for (auto& it : key_layout.get_padded_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tquery_dims = ["; - for (auto& it : querry_layout.get_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - std::cout << "\tquery_pads = ["; - for (auto& it : querry_layout.get_padded_dims()) { - std::cout << static_cast(it) << ", "; - } - std::cout << "]" << std::endl; - } - - auto out_shape = params.output_layouts[0].get_shape(); - const size_t q_len = out_shape[0]; + const auto M = rtp->M; + const auto N = rtp->N; + const auto K = rtp->K; - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` - const uint32_t N = static_cast(kv_len / STRIDE); - const uint32_t K = static_cast(STRIDE * head_size); auto get_simple_pitch = [](const layout& layout) { size_t pitch = 1; auto dims_padding = layout.get_padded_dims(); @@ -685,17 +651,15 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { } return pitch; }; + auto querry_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; const size_t query_pitch = get_simple_pitch(querry_layout) * STRIDE; const size_t slice_no = 0, slice = 0; - const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); - const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); - //# loop order walks HQ first and the step is WALK_HQ, 1 means not walk HQ, 2 means walks 2 heads first. Valid value: 1, 2, 4... const size_t WALK_HQ = desc->heads_num != desc->kv_heads_num ? 2 : 1; auto& wgs = kd.params.workGroups; - wgs.global = {N_kq_groups * (q_stride_pad / BLOCK_WG_M) * SG_N * WALK_HQ, SG_M, heads_num / WALK_HQ}; + wgs.global = {rtp->N_kq_groups * (rtp->q_stride_pad / BLOCK_WG_M) * SG_N * WALK_HQ, SG_M, desc->heads_num / WALK_HQ}; wgs.local = {SG_N, SG_M, 1}; const uint32_t q_start_strided = N - M; @@ -705,17 +669,6 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { std::vector scaler_value = {M, N, K, query_pitch, slice_no, slice, q_start_strided}; scalars.resize(scaler_value.size()); - if (DEBUG_ENABLED) { // Debug - size_t kv_len = get_kv_len(params, PagedAttentionStage::PREFILL); - size_t max_context_len = get_max_context_len(params); - size_t past_len = get_past_len(params, 0); - std::cout << "XAttentionEstimateGEMMQK::get_dispatch_data_func: " - << "N_kq_groups: " << N_kq_groups << ", q_stride_pad: " << q_stride_pad << ", scaler_value: " << PartialShape(scaler_value) - << ", kv_len: " << kv_len << ", max_context_len = " << max_context_len << ", past_len = " << past_len << ", gws: [" << wgs.global[0] - << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" - << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; - } - for (size_t i = 0; i < scaler_value.size(); ++i) { if (i == 4 || i == 5) { scalars[i].t = ScalarDescriptor::Types::INT32; @@ -755,61 +708,32 @@ Arguments XAttentionEstimateFindBlock::get_arguments_desc(const kernel_impl_para DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { - assert(!params.is_dynamic()); - auto& wgs = kd.params.workGroups; - + OPENVINO_ASSERT(!params.is_dynamic()); + OPENVINO_ASSERT(rt_params != nullptr); + auto rtp = static_cast(rt_params); const auto desc = params.typed_desc(); - // auto rtp = static_cast(rt_params); - - assert(rt_params != nullptr); - const uint32_t wg_k = BLOCK_WG_M; - const uint32_t wg_q = BLOCK_WG_N; - const size_t block_size = get_xattn_block_size(params); - OPENVINO_ASSERT(wg_k % block_size == 0, "wg_k should be multiple of block_size then there is no tails from block_size"); - OPENVINO_ASSERT(wg_q % block_size == 0, "wg_q should be multiple of block_size then there is no tails from block_size"); - - const size_t sum_per_n_token_in_block = static_cast(block_size / STRIDE); + auto& wgs = kd.params.workGroups; - // const size_t batch = params.input_layouts[PagedAttentionInputIdx::QUERY].get_partial_shape()[0].get_length(); const size_t heads_num = desc->heads_num; - // const size_t head_size = desc->k_head_size; auto out_shape = params.output_layouts[0].get_shape(); - const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; const size_t q_len = out_shape[0]; - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` - const uint32_t N = static_cast(kv_len / STRIDE); - const uint32_t q_stride = M; - const uint32_t k_stride = N; - const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); - const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); - const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); - const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); - const uint32_t k_block_pad = k_block_in_group * N_kq_groups; - const uint32_t q_block_pad = ceil_div(q_len, block_size); - - const uint32_t q_block = ceil_div(q_stride, sum_per_n_token_in_block); - const uint32_t k_block = ceil_div(k_stride, sum_per_n_token_in_block); + const size_t block_size = get_xattn_block_size(params); + const size_t sum_per_n_token_in_block = static_cast(block_size / STRIDE); + const uint32_t q_block = ceil_div(rtp->M, sum_per_n_token_in_block); + const uint32_t k_block = ceil_div(rtp->N, sum_per_n_token_in_block); const float xattn_thresh = get_xattn_thresh(params); - wgs.global = {q_block_pad, heads_num, 1}; + wgs.global = {rtp->q_block_pad, heads_num, 1}; wgs.local = {1, 1, 1}; auto& scalars = kd.params.scalars; - std::vector scaler_value = {q_len, q_stride, q_stride_pad, q_block_pad, k_block_pad, k_block - q_block}; + std::vector scaler_value = {q_len, rtp->M, rtp->q_stride_pad, rtp->q_block_pad, rtp->k_block_pad, k_block - q_block}; scalars.resize(scaler_value.size() + 1); - if (DEBUG_ENABLED) { // Debug - std::cout << "XAttentionEstimateFindBlock::get_dispatch_data_func: " - << "xattn_thresh : " << xattn_thresh << " k_block: " << k_block << ", q_block: " << q_block << " q_stride: " << q_stride - << ", q_stride_pad: " << q_stride_pad << ", k_block_pad: " << k_block_pad << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " - << wgs.global[2] << "]" - << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; - } - for (size_t i = 0; i < scaler_value.size(); ++i) { scalars[i].t = ScalarDescriptor::Types::UINT32; scalars[i].v.u32 = static_cast(scaler_value[i]); @@ -849,35 +773,20 @@ Arguments XAttentionEstimatePostProc::get_arguments_desc(const kernel_impl_param DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const { return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { - assert(!params.is_dynamic()); - auto& wgs = kd.params.workGroups; - + OPENVINO_ASSERT(!params.is_dynamic()); + OPENVINO_ASSERT(rt_params != nullptr); + auto rtp = static_cast(rt_params); const auto desc = params.typed_desc(); - assert(rt_params != nullptr); + auto& wgs = kd.params.workGroups; - const size_t block_size = get_xattn_block_size(params); - const size_t heads_num = desc->heads_num; + const uint32_t q_block_pad_merged = ceil_div(rtp->q_block_pad, MERGED_Q_NUM); - auto out_shape = params.output_layouts[0].get_shape(); - const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; - const size_t q_len = out_shape[0]; - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` - const uint32_t N = static_cast(kv_len / STRIDE); - const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); - const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N); - - const uint32_t sum_per_token_in_block = static_cast(block_size / STRIDE); - const uint32_t k_block_in_group = static_cast(BLOCK_WG_N / sum_per_token_in_block); - const uint32_t k_block_pad = k_block_in_group * N_kq_groups; - const uint32_t q_block_pad = ceil_div(q_len, block_size); - const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); - - wgs.global = {q_block_pad_merged, heads_num, 1}; + wgs.global = {q_block_pad_merged, desc->heads_num, 1}; wgs.local = {1, 1, 1}; auto& scalars = kd.params.scalars; - std::vector scaler_value = {q_stride_pad, q_block_pad, k_block_pad}; + std::vector scaler_value = {rtp->q_stride_pad, rtp->q_block_pad, rtp->k_block_pad}; scalars.resize(scaler_value.size()); for (size_t i = 0; i < scaler_value.size(); ++i) { diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 4498c70980a51c..1740c81c87c3ef 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -44,12 +44,18 @@ constexpr uint32_t MERGED_Q_NUM = PA_KV_CACHE_BLOCK_SIZE_XATTN / XATTN_BLOCK_SIZ enum class PagedAttentionStage : uint8_t { GENERATE = 0, PREFILL = 1, MIXED = 2, UNKNOWN = 3 }; struct PagedAttentionRuntimeParams : public ImplRuntimeParams { PagedAttentionStage stage; - size_t num_of_partitions; - size_t partition_size; size_t max_context_len; - size_t paged_attention_aligned_seq_len; - size_t xattn_q_block_pad; - size_t xattn_k_block_pad; + // below are rt params for decoding + size_t num_of_partitions; + // below are rt params for xattn + size_t q_block_pad; + size_t k_block_pad; + size_t q_stride_pad; + uint32_t q_block_pad_merged; + size_t N_kq_groups; + size_t M; + size_t N; + size_t K; }; enum PagedAttentionInternBuffIdx { @@ -71,7 +77,6 @@ PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_par size_t get_max_context_len(const kernel_impl_params& params); size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx); size_t get_partition_size(const bool has_xattention); -size_t get_partition_num(const size_t kv_len, const bool has_xattention); float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx = 0); bool bypass_xattn(const kernel_impl_params& impl_param); From c2bde5bb8346eb8086f9cb90b97c96ff70b1d4a6 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Mon, 20 Oct 2025 16:43:28 +0800 Subject: [PATCH 86/96] fix clang-format style issues --- .../src/graph/impls/cm/paged_attention.cpp | 34 +++++++---------- .../graph/impls/cm/paged_attention_gen.cpp | 37 +++++++++---------- .../graph/impls/cm/paged_attention_gen.hpp | 12 +++--- .../graph/impls/ocl_v2/primitive_ocl_base.hpp | 10 ++--- 4 files changed, 39 insertions(+), 54 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index fb49c6cc556da1..4de0b8ce5efefd 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -52,14 +52,13 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { } } - void update_xattn_rt_params(const primitive_inst& instance) { - const auto& params = *instance.get_impl_params(); - OPENVINO_ASSERT(!params.is_dynamic()); + void update_xattn_rt_params(const kernel_impl_params& params) { const auto desc = params.typed_desc(); + // XAttention estimate is following afer kvcache_update. auto out_shape = params.output_layouts[0].get_shape(); const size_t block_size = get_xattn_block_size(params); - const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE; + const size_t kv_len = get_max_context_len(params); const size_t q_len = out_shape[0]; const size_t N = kv_len / STRIDE; const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); @@ -73,13 +72,10 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { rt_params->q_block_pad = q_block_pad; rt_params->k_block_pad = k_block_pad; - // XAttention estimate is following afer kvcache_update. const size_t head_size = desc->k_head_size; - auto querry_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; - - const uint32_t M = static_cast(q_len / STRIDE); //# will slient drop the tails which is less than `stride` - const uint32_t K = static_cast(STRIDE * head_size); + const auto M = q_len / STRIDE; //# will slient drop the tails which is less than `stride` + const auto K = STRIDE * head_size; const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); @@ -97,28 +93,26 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { } const auto& params = *instance.get_impl_params(); + OPENVINO_ASSERT(!params.is_dynamic()); auto rt_params = static_cast(m_rt_params.get()); const auto& desc = params.typed_desc(); rt_params->stage = get_paged_attention_stage(params); const auto max_context_len = get_max_context_len(params); rt_params->max_context_len = max_context_len; - GPU_DEBUG_TRACE_DETAIL << "update_rt_params for stage: " << static_cast(rt_params->stage) - << " max_context_len: " << rt_params->max_context_len << std::endl; + GPU_DEBUG_TRACE_DETAIL << "update_rt_params for stage: " << static_cast(rt_params->stage) << " max_context_len: " << rt_params->max_context_len + << std::endl; if (rt_params->stage == PagedAttentionStage::GENERATE) { auto partition_size = get_partition_size(desc->has_xattention); rt_params->num_of_partitions = ceil_div(max_context_len, partition_size); - GPU_DEBUG_TRACE_DETAIL << " partition_size: " << partition_size - << " num_of_partitions: " << rt_params->num_of_partitions << std::endl; + GPU_DEBUG_TRACE_DETAIL << " partition_size: " << partition_size << " num_of_partitions: " << rt_params->num_of_partitions << std::endl; } else { if (desc->has_xattention) { - update_xattn_rt_params(instance); + update_xattn_rt_params(params); } } - - } // update impl_parameter and rt_parameter @@ -184,8 +178,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { internal_buffers.emplace_back(tmp_out_elements_count, ov::element::f32); // 0: intermediate partition output internal_buffers.emplace_back(buf_elements_count, ov::element::f32); // 1: softmax exp_sums - GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * 4 - << " exp_sums=" << buf_elements_count * 4 << std::endl; + GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * 4 << " exp_sums=" << buf_elements_count * 4 << std::endl; } else { internal_buffers.emplace_back(16, ov::element::f32); // 0: intermediate partition output internal_buffers.emplace_back(16, ov::element::f32); // 1: softmax exp_sums @@ -205,9 +198,8 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { internal_buffers.emplace_back(count_elements_mask_merged, ov::element::boolean); // 5: sparse_block_mask_wg GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: count_kq_max_wg=" << count_kq_max_wg * 4 - << " count_kq_exp_partial_sum=" << count_kq_exp_partial_sum * 4 - << " count_elements_mask=" << count_elements_mask * 1 - << " count_elements_mask_merged=" << count_kq_exp_partial_sum * 1 << std::endl; + << " count_kq_exp_partial_sum=" << count_kq_exp_partial_sum * 4 << " count_elements_mask=" << count_elements_mask * 1 + << " count_elements_mask_merged=" << count_kq_exp_partial_sum * 1 << std::endl; } } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index fa0c58f7992986..b539f1bda01c4a 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -268,10 +268,10 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() size_t value_offset = get_simple_offset(value_layout); if (DEBUG_ENABLED) { // Debug - std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " - << "kv_len: " << kv_len << ", key_pitch: " << key_pitch << ", key_offset: " << key_offset << ", value_pitch: " << value_pitch - << ", value_offset: " << value_offset << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" - << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " << "kv_len: " << kv_len << ", key_pitch: " << key_pitch + << ", key_offset: " << key_offset << ", value_pitch: " << value_pitch << ", value_offset: " << value_offset << ", gws: [" << wgs.global[0] + << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] + << "]" << std::endl; } // TODO: support multiple sequences size_t batch_size_in_sequences = 1; @@ -304,7 +304,7 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); if (desc->has_xattention) { - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK}); // sparse_block_mask + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK}); // sparse_block_mask args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK_MERGED}); // sparse_block_mask_wg } @@ -373,8 +373,7 @@ DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() con std::cout << "PagedAttentionGeneratorMultiToken::get_dispatch_data_func: \n" << "\tbatch: " << batch << ", heads_num: " << heads_num << ", q_len: " << q_len << ", q_step: " << q_step << ", wg_seq_len: " << wg_seq_len << ", wg_count: " << wg_count << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " - << wgs.global[2] << "]" - << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } auto num_scalers = desc->has_xattention ? 4 : 1; scalars.resize(num_scalers); @@ -443,7 +442,7 @@ Arguments PagedAttentionGeneratorSingleToken::get_arguments_desc(const kernel_im // outputs args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_PARTITIONOUT}); // partition output - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_EXPSUMS}); // lse output + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_EXPSUMS}); // lse output // scalar args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len==1 @@ -476,11 +475,10 @@ DispatchDataFunc PagedAttentionGeneratorSingleToken::get_dispatch_data_func() co size_t kv_len = get_kv_len(params, PagedAttentionStage::GENERATE); size_t max_context_len = get_max_context_len(params); size_t past_len = get_past_len(params, 0); - std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " - << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", kv_len: " << kv_len - << ", max_context_len = " << max_context_len << ", past_len = " << past_len << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] - << ", " << wgs.global[2] << "]" - << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " << "batch: " << batch << ", heads_num: " << heads_num + << ", partition_num: " << partition_num << ", kv_len: " << kv_len << ", max_context_len = " << max_context_len + << ", past_len = " << past_len << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" + << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } for (size_t i = 0; i < scaler_value.size(); ++i) { scalars[i].t = ScalarDescriptor::Types::INT32; @@ -509,8 +507,8 @@ Arguments PagedAttentionGeneratorSingleTokenFinalization::get_arguments_desc(con OPENVINO_THROW("[GPU][CM] PagedAttentionGeneratorSingleTokenFinalization with scores output is not supported yet"); args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_PARTITIONOUT}); // partition data - args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // output - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_EXPSUMS}); // lse + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // output + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_EXPSUMS}); // lse // scalar args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // kv_partition_num @@ -540,9 +538,8 @@ DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_da scalars.resize(scaler_value.size()); if (DEBUG_ENABLED) { // Debug - std::cout << "PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_data_func: " - << "batch: " << batch << ", heads_num: " << heads_num << ", partition_num: " << partition_num << ", gws: [" << wgs.global[0] << ", " - << wgs.global[1] << ", " << wgs.global[2] << "]" + std::cout << "PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_data_func: " << "batch: " << batch << ", heads_num: " << heads_num + << ", partition_num: " << partition_num << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; } for (size_t i = 0; i < scaler_value.size(); ++i) { @@ -614,7 +611,7 @@ Arguments XAttentionEstimateGEMMQK::get_arguments_desc(const kernel_impl_params& args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins // outputs - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_MAX}); // kq_max_wg + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_MAX}); // kq_max_wg args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_EXPSUMS}); // kq_exp_partial_sum // scalar @@ -688,7 +685,7 @@ Arguments XAttentionEstimateFindBlock::get_arguments_desc(const kernel_impl_para Arguments args; // inputs - args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_MAX}); // kq_max_wg + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_MAX}); // kq_max_wg args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_EXPSUMS}); // kq_exp_partial_sum // outputs diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index 1740c81c87c3ef..d4f6952cb2d35a 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -60,13 +60,13 @@ struct PagedAttentionRuntimeParams : public ImplRuntimeParams { enum PagedAttentionInternBuffIdx { // for decoding kernels - DECODE_PARTITIONOUT = 0, // 0: intermediate partition output - DECODE_EXPSUMS = 1, // 1: softmax exp_sums + DECODE_PARTITIONOUT = 0, // 0: intermediate partition output + DECODE_EXPSUMS = 1, // 1: softmax exp_sums // for xattn kernels - XATTN_GEMMQK_MAX = 2, // 2: kq_max_wg - XATTN_GEMMQK_EXPSUMS = 3, // 3: kq_exp_partial_sum - XATTN_BLOCKMASK = 4, // 4: sparse_block_mask - XATTN_BLOCKMASK_MERGED = 5, // 5: sparse_block_mask_wg + XATTN_GEMMQK_MAX = 2, // 2: kq_max_wg + XATTN_GEMMQK_EXPSUMS = 3, // 3: kq_exp_partial_sum + XATTN_BLOCKMASK = 4, // 4: sparse_block_mask + XATTN_BLOCKMASK_MERGED = 5, // 5: sparse_block_mask_wg }; //----------------------------------------------------------------------------------------------------------------- diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/primitive_ocl_base.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/primitive_ocl_base.hpp index 6507e97188ff97..5fd64ded54411f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/primitive_ocl_base.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/primitive_ocl_base.hpp @@ -251,13 +251,9 @@ struct PrimitiveImplOCL : public cldnn::primitive_impl { GPU_DEBUG_TRACE_DETAIL << "\t" << i << ": type = " << static_cast(params.arguments[i].t) << ", index = " << params.arguments[i].index << '\n'; } - GPU_DEBUG_TRACE_DETAIL << "Memory buffers:" - << "shape_info=" << args.shape_info << " " - << "inputs=" << args.inputs.size() << " " - << "outputs=" << args.outputs.size() << " " - << "intermediates=" << args.intermediates.size() << " " - << "weights=" << args.weights << " " - << "scalars=" << (args.scalars ? args.scalars->size() : 0) << "\n"; + GPU_DEBUG_TRACE_DETAIL << "Memory buffers:" << "shape_info=" << args.shape_info << " " << "inputs=" << args.inputs.size() << " " + << "outputs=" << args.outputs.size() << " " << "intermediates=" << args.intermediates.size() << " " + << "weights=" << args.weights << " " << "scalars=" << (args.scalars ? args.scalars->size() : 0) << "\n"; stream.set_arguments(*stage.kernel, params, args); kd.need_args_update = false; } From 554ebf43d8f9a6c3807ad846fdac94d16a25ef7d Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 21 Oct 2025 14:45:52 +0800 Subject: [PATCH 87/96] merge xattention tests into paged_attention tests (#63) --- .../test_cases/paged_attention_gpu_test.cpp | 580 +++++--- .../unit/test_cases/xattention_gpu_test.cpp | 1273 ----------------- 2 files changed, 381 insertions(+), 1472 deletions(-) delete mode 100644 src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index c31a8c0fd69056..9de30c473d7806 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1,20 +1,23 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // -#include "test_utils.h" -#include "random_generator.hpp" - #include #include #include -#include #include +#include #include #include #include #include #include +#include + +#include "openvino/runtime/tensor.hpp" +#include "primitive_inst.h" +#include "random_generator.hpp" +#include "test_utils.h" using namespace cldnn; using namespace ov::intel_gpu; @@ -131,7 +134,8 @@ struct PagedAttentionManager { bool kv_cache_compression, ov::internal::CacheQuantMode key_cache_quant_mode, bool has_score_aggregation, - CacheRotationDescriptor rotation_config) + CacheRotationDescriptor rotation_config, + std::vector threshold) : num_heads(num_heads) , num_kv_heads(num_kv_heads) , k_head_size(k_head_size) @@ -149,6 +153,9 @@ struct PagedAttentionManager { // init subsequence_begins and block_indices_begins subsequence_begins.push_back(0); block_indices_begins.push_back(0); + for (int i = 0; i < static_cast(threshold.size()); i++) { + xattention_threshold.emplace_back(static_cast(threshold[i])); + } int max_len = 0; for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { @@ -224,10 +231,7 @@ struct PagedAttentionManager { return get_QKV_memory(value_data, num_kv_heads, v_head_size, true); } -/* TODO: These CM kernels test should be run only if CM compiler is ready on the system */ -#define ENABLE_PA_CM_PATH 1 // Define it here to make the build passed -#if ENABLE_PA_CM_PATH - memory::ptr get_key_cache_memory() { + memory::ptr get_key_cache_memory_cm() { auto key_cache_dt = data_types::f16; auto adjusted_head_size = k_head_size; if (kv_cache_compression) { @@ -280,7 +284,6 @@ struct PagedAttentionManager { return memory; } -#else memory::ptr get_key_cache_memory() { auto key_cache_dt = data_types::f16; auto adjusted_head_size = k_head_size; @@ -378,7 +381,6 @@ struct PagedAttentionManager { return memory; } -#endif memory::ptr get_value_cache_memory() { auto value_cache_dt = data_types::f16; @@ -611,6 +613,26 @@ struct PagedAttentionManager { return data; } + static std::vector generate_realistic_data(size_t num_heads, size_t tokens_num, size_t k_head_size) { + std::vector data(num_heads * tokens_num * k_head_size); + + std::mt19937 gen(1234); + std::normal_distribution dist(0.0f, 0.1f); + + for (size_t h = 0; h < num_heads; ++h) { + for (size_t t = 0; t < tokens_num; ++t) { + for (size_t d = 0; d < k_head_size; ++d) { + float val = dist(gen); + if (t > 0) + val = 0.8f * val + 0.2f * static_cast(data[h * tokens_num * k_head_size + (t - 1) * k_head_size + d]); + data[h * tokens_num * k_head_size + t * k_head_size + d] = static_cast(val); + } + } + } + + return data; + } + static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { const size_t total_elements_num = per_block ? rotated_blocks_num : rotated_blocks_num * block_size; @@ -685,7 +707,7 @@ struct PagedAttentionReference { , test_engine(pam.test_engine) , test_stream(pam.test_stream) {} - std::pair, std::vector> get_reference() { + std::pair, std::vector> get_reference(bool has_xattention, std::vector threshold) { std::vector ref_data_output; std::vector ref_scores_output; @@ -722,7 +744,8 @@ struct PagedAttentionReference { auto window_size = pam.has_score_aggregation ? pam.score_aggregation[i] : 1; - auto subsequence_ref_results = run_reference(pam.query_data[i], + auto subsequence_ref_results = run_reference(has_xattention, + pam.query_data[i], key_data, pam.value_data[i], subsequence_desc.num_tokens, @@ -733,7 +756,8 @@ struct PagedAttentionReference { pam.v_head_size, window_size, pam.sliding_window_size, - pam.get_default_scale()); + pam.get_default_scale(), + static_cast(threshold[i])); // concatenate all subsequences into one vector ref_data_output.insert(ref_data_output.end(), @@ -748,19 +772,22 @@ struct PagedAttentionReference { } private: - std::pair, std::vector> - run_reference(const std::vector& query_data, - const std::vector& key_data, - const std::vector& value_data, - int num_queries, - int num_keys, - int num_heads, - int num_kv_heads, - int k_head_size, - int v_head_size, - int window_size, - int sliding_window_size, - float scale) { + std::pair, std::vector> run_reference(bool has_xattention, + const std::vector& query_data, + const std::vector& key_data, + const std::vector& value_data, + int num_queries, + int num_keys, + int num_heads, + int num_kv_heads, + int k_head_size, + int v_head_size, + int window_size, + int sliding_window_size, + float scale, + double threshold = 0.9, + size_t block_size = 128, + size_t stride = 16) { auto query_shape = ov::PartialShape{1, num_queries, num_heads, k_head_size}; auto key_shape = ov::PartialShape{1, num_keys, num_kv_heads, k_head_size}; auto value_shape = ov::PartialShape{1, num_keys, num_kv_heads, v_head_size}; @@ -782,7 +809,6 @@ struct PagedAttentionReference { auto query_mem = test_engine.allocate_memory(query_layout); auto key_mem = test_engine.allocate_memory(key_layout); auto value_mem = test_engine.allocate_memory(value_layout); - auto mask_mem = get_mask_mem(num_queries, num_keys, num_heads, sliding_window_size); auto scale_mem = test_engine.allocate_memory(scale_layout); set_values(query_mem, query_data); @@ -790,6 +816,54 @@ struct PagedAttentionReference { set_values(value_mem, value_data); set_values(scale_mem, {static_cast(scale)}); + ov::reference::XAttentionRetainedBlockIndicesForAllHeads retained_blocks; + if (num_queries >= static_cast(block_size) && has_xattention) { + auto reorder_qhk_to_hqd = [&](const std::vector& src, int outer_len, int num_heads, int head_dim) { + std::vector dst(num_heads * outer_len * head_dim); + for (int h = 0; h < num_heads; ++h) { + size_t dst_h_off = static_cast(h) * outer_len * head_dim; + for (int i = 0; i < outer_len; ++i) { + size_t src_off = static_cast(i) * num_heads * head_dim + static_cast(h) * head_dim; + std::copy_n(&src[src_off], head_dim, &dst[dst_h_off + static_cast(i) * head_dim]); + } + } + return dst; + }; + + const auto query_data_3d = reorder_qhk_to_hqd(query_data, num_queries, num_heads, k_head_size); + const auto key_data_3d = reorder_qhk_to_hqd(key_data, num_keys, num_heads, k_head_size); + const size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; + const size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; + + std::vector query_padded(num_heads * padded_q * k_head_size, 0.f); + std::vector key_padded(num_heads * padded_k * k_head_size, 0.f); + + for (int h = 0; h < num_heads; ++h) { + const auto* q_src = &query_data_3d[h * num_queries * k_head_size]; + const auto* k_src = &key_data_3d[h * num_keys * k_head_size]; + auto* q_dst = &query_padded[h * padded_q * k_head_size]; + auto* k_dst = &key_padded[h * padded_k * k_head_size]; + + std::transform(q_src, q_src + num_queries * k_head_size, q_dst, [](ov::float16 v) { + return static_cast(v); + }); + std::transform(k_src, k_src + num_keys * k_head_size, k_dst, [](ov::float16 v) { + return static_cast(v); + }); + } + ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); + retained_blocks = selector.select_blocks(query_padded.data(), + {static_cast(num_heads), padded_q, static_cast(k_head_size)}, + key_padded.data(), + {static_cast(num_heads), padded_k, static_cast(k_head_size)}); + } + auto mask_mem = get_mask_mem_combined_multi_head(num_queries, + num_keys, + num_heads, + num_kv_heads, + sliding_window_size, + retained_blocks, + static_cast(block_size)); topology topology; if (num_heads == num_kv_heads) { topology.add(input_layout("query", query_layout), @@ -885,69 +959,102 @@ struct PagedAttentionReference { return output_data; } - memory::ptr get_mask_mem(int num_queries, int num_keys, int num_heads, int sliding_window_size) { - /* - * Two kinds of masks: - * - * Case 1 (N == K): - * num_queries = N - * num_keys = K = N - * k_head_size = H - * Q [N, H] * K[H, N] - * QK [N, N] - * 0 1 N - * 0 [ 0, MIN, .., MIN ] - * 1 [ 0, 0, .., MIN ] - * [ .., .., .., MIN ] - * N [ 0, 0, .., 0 ] - * - * Case 2 (N != K): - * num_queries = N - * num_keys = K - * k_head_size = H - * past_len = P = K - N + 1 - * Q [N, H] * K[H, K] - * QK [N, K] - * 0 1 2 P .. K - * 0 [ 0, 0, 0, MIN, MIN, MIN ] - * 1 [ 0, 0, 0, 0, MIN, MIN ] - * [ .., .., .., .., .., MIN ] - * N [ 0, 0, 0, 0, .., 0 ] - * - * Shapes: - * Q [1, num_heads, num_queries, k_head_size] - * K [1, num_heads, k_head_size, num_keys] - * Q*K [1, num_heads, num_queries, num_keys] - */ - - auto mask_shape = ov::PartialShape{ 1, 1, num_queries, num_keys }; + memory::ptr get_mask_mem_combined_multi_head(int num_queries, + int num_keys, + int num_heads, + int num_kv_heads, + int sliding_window_size, + const ov::reference::XAttentionRetainedBlockIndicesForAllHeads& retained_blocks, + int block_size) { + int heads_per_kv = num_heads / num_kv_heads; + + ov::PartialShape mask_shape; + if (retained_blocks.empty()) { + mask_shape = ov::PartialShape{1, 1, num_queries, num_keys}; + } else if (num_heads == num_kv_heads) { + mask_shape = ov::PartialShape{1, num_heads, num_queries, num_keys}; + } else { + mask_shape = ov::PartialShape{num_kv_heads, heads_per_kv, num_queries, num_keys}; + } + auto mask_layout = layout{mask_shape, data_types::f16, format::bfyx}; auto mask_mem = test_engine.allocate_memory(mask_layout); - mem_lock mem_ptr(mask_mem, test_stream); - if (sliding_window_size == 0) { - int past_len = num_keys - num_queries + 1; - for (int i = 0; i < num_queries; i++) { - for (int j = 0; j < num_keys; j++) { - mem_ptr[i * num_keys + j] = j >= past_len + i ? std::numeric_limits::lowest() - : ov::float16(0.f); + size_t total_elems = mask_layout.count(); + for (size_t i = 0; i < total_elems; ++i) + mem_ptr[i] = std::numeric_limits::lowest(); + if (retained_blocks.empty()) { + if (sliding_window_size == 0) { + int past_len = num_keys - num_queries + 1; + for (int i = 0; i < num_queries; i++) { + for (int j = 0; j < num_keys; j++) { + mem_ptr[i * num_keys + j] = j >= past_len + i ? std::numeric_limits::lowest() : ov::float16(0.f); + } + } + } else { + int sliding_left = num_keys - num_queries - sliding_window_size + 1; + int past_len = num_keys - num_queries + 1; + + for (int i = 0; i < num_queries; i++) { + for (int j = 0; j < num_keys; j++) { + bool is_min; + if (num_queries == num_keys) { + is_min = (j >= sliding_left + i) && (j <= i) ? 0 : 1; + } else { + is_min = (j >= sliding_left + i) && (j < past_len + i) ? 0 : 1; + } + + mem_ptr[i * num_keys + j] = is_min ? std::numeric_limits::lowest() : ov::float16(0.f); + } } } } else { - int sliding_left = num_keys - num_queries - sliding_window_size + 1; - int past_len = num_keys - num_queries + 1; - - for (int i = 0; i < num_queries; i++) { - for (int j = 0; j < num_keys; j++) { - bool is_min; - if (num_queries == num_keys) { - is_min = (j >= sliding_left + i) && (j <= i) ? 0 : 1; + for (int h = 0; h < num_heads; ++h) { + int kv_idx = (num_heads == num_kv_heads) ? 0 : (h / heads_per_kv); + int head_in_kv = (num_heads == num_kv_heads) ? h : (h % heads_per_kv); + + size_t head_offset = (static_cast(kv_idx) * heads_per_kv + static_cast(head_in_kv)) * static_cast(num_queries) * + static_cast(num_keys); + + for (int i = 0; i < num_queries; i++) { + int left_idx = 0; + int right_idx = 0; + + if (sliding_window_size == 0) { + int past_len = num_keys - num_queries + 1; + right_idx = past_len + i - 1; + left_idx = 0; } else { - is_min = (j >= sliding_left + i) && (j < past_len + i) ? 0 : 1; + int sliding_left = num_keys - num_queries - sliding_window_size + 1; + int past_len = num_keys - num_queries + 1; + if (num_queries == num_keys) { + left_idx = sliding_left + i; + right_idx = i; + } else { + left_idx = sliding_left + i; + right_idx = past_len + i - 1; + } } - mem_ptr[i * num_keys + j] = is_min ? std::numeric_limits::lowest() : ov::float16(0.f); + left_idx = std::max(0, left_idx); + right_idx = std::min(num_keys - 1, right_idx); + + for (const auto& [q_block_idx, k_block_idx] : retained_blocks[h]) { + int q_start = q_block_idx * block_size; + int q_end = std::min(q_start + block_size, num_queries); + int k_start = k_block_idx * block_size; + int k_end = std::min(k_start + block_size, num_keys); + + if (i < q_start || i >= q_end) + continue; + + for (int j = k_start; j < k_end; j++) { + if (j >= left_idx && j <= right_idx) { + mem_ptr[head_offset + i * num_keys + j] = ov::float16(0.f); + } + } + } } } } @@ -1019,7 +1126,8 @@ struct PagedAttentionTest : public ::testing::TestWithParam { p.kv_cache_compression, p.key_cache_quant_mode, p.scores_mode == ScoresMode::SNAPKV, - p.rotation_config); + p.rotation_config, + p.threshold); if (p.kv_cache_compression) tolerance = 25e-3; @@ -1027,8 +1135,13 @@ struct PagedAttentionTest : public ::testing::TestWithParam { auto query_mem = pam.get_query_memory(); auto key_mem = pam.get_key_memory(); auto value_mem = pam.get_value_memory(); - - auto key_cache_mem = pam.get_key_cache_memory(); + + memory::ptr key_cache_mem; + if (p.has_xattention) { + key_cache_mem = pam.get_key_cache_memory_cm(); + } else { + key_cache_mem = pam.get_key_cache_memory(); + } auto value_cache_mem = pam.get_value_cache_memory(); auto past_lens_mem = pam.get_past_lens_memory(); @@ -1170,6 +1283,9 @@ struct PagedAttentionTest : public ::testing::TestWithParam { pa_prim.has_score_aggregation = p.scores_mode == ScoresMode::SNAPKV; pa_prim.sliding_window = p.sliding_window_size; pa_prim.is_key_by_channel = (p.key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL); + if (p.has_xattention) { + pa_prim.has_xattention = true; + } topology topology; @@ -1248,8 +1364,12 @@ struct PagedAttentionTest : public ::testing::TestWithParam { if (p.scores_mode != ScoresMode::DISABLED) { output_scores_mem = outputs.at("output_scores").get_memory(); } - auto ref_data = PagedAttentionReference(pam).get_reference(); - compare(output_data_mem, output_scores_mem, ref_data); + auto ref_data = PagedAttentionReference(pam).get_reference(p.has_xattention, p.threshold); + if (p.has_xattention) { + compare_xattention(output_data_mem, output_scores_mem, ref_data); + } else { + compare(output_data_mem, output_scores_mem, ref_data); + } } void compare(memory::ptr data_output_mem, memory::ptr scores_output_mem, std::pair, std::vector> ref_data) { @@ -1269,6 +1389,42 @@ struct PagedAttentionTest : public ::testing::TestWithParam { } } } + + void compare_xattention(memory::ptr data_output_mem, memory::ptr scores_output_mem, std::pair, std::vector> ref_data) { + if (data_output_mem) { + ASSERT_EQ(data_output_mem->count(), ref_data.first.size()); + mem_lock mem_ptr(data_output_mem, get_test_stream()); + int mismatch_count = 0; + for (size_t i = 0; i < data_output_mem->count(); i++) { + if (std::fabs(static_cast(mem_ptr[i]) - static_cast(ref_data.first[i])) > tolerance) { + mismatch_count++; + } + } + EXPECT_LE(mismatch_count, int(data_output_mem->count() * 0.04)); + } + + if (scores_output_mem) { + ASSERT_EQ(scores_output_mem->count(), ref_data.second.size()); + mem_lock mem_ptr(scores_output_mem, get_test_stream()); + int mismatch_count = 0; + for (size_t i = 0; i < scores_output_mem->count(); i++) { + if (std::fabs(static_cast(mem_ptr[i]) - static_cast(ref_data.second[i])) > tolerance) { + mismatch_count++; + } + } + EXPECT_LE(mismatch_count, int(scores_output_mem->count() * 0.04)); + } + } + + static bool check_cm_available() { + auto& engine = get_test_engine(); + ExecutionConfig config = get_test_default_config(engine); + if (!cldnn::check_cm_jit_support(engine, config) || !engine.get_device_info().supports_immad) { + return false; + } + + return true; + } }; struct paged_attention_test_params { @@ -1278,7 +1434,9 @@ struct paged_attention_test_params { int k_head_size; int v_head_size; int block_size; + std::vector threshold; int sliding_window_size; + bool has_xattention; bool kv_cache_compression; ov::internal::CacheQuantMode key_cache_quant_mode; bool dynamic_paddings; @@ -1294,6 +1452,15 @@ TEST_P(paged_attention_test, basic) { execute(p); } +class xattention_test : public PagedAttentionTest {}; +TEST_P(xattention_test, basic) { + if (!check_cm_available()) + GTEST_SKIP(); + auto p = GetParam(); + + execute(p); +} + const auto ENABLE_CACHE_COMPRESSION = true; const auto DISABLE_CACHE_COMPRESSION = false; const auto DISABLE_SCORES = ScoresMode::DISABLED; @@ -1307,127 +1474,142 @@ const auto DYNAMIC_INPUT_PAD = true; const auto ENABLE_FA_V2 = false; const auto DISABLE_FA_V2 = true; -#ifndef ENABLE_PA_CM_PATH INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ /* with scores output, use SnapKV */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* with scores output */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 16, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 128, 96, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 48, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 96, 128, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 32, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 96, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 48, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 16, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 16, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 128, 96, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 48, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 96, 128, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 32, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 96, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 48, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 16, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* without scores output, dynamic input query paddings */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* with scores, per_block rotation */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 48, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 48, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* with scores, per_token rotation */ - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 128, 192, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 128, 192, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* without scores output, dynamic input query paddings, KV-cache compression */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* with scores, per_block rotation, KV-cache compression */ - paged_attention_test_params{ {{1, 34}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 34}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token /* with scores, per_token rotation, KV-cache compression */ - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* With sliding windows */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 6, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 6, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{512, 0}}, 2, 2, 64, 32, 16, 20, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{512, 0}}, 2, 2, 64, 32, 16, 20, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, 8, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 64, 16, 128, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, 4, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{5, 10}}, 2, 2, 64, 64, 16, 2, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{8, 8}}, 16, 16, 256, 256, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{34, 0}}, 2, 2, 32, 32, 16, 2, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1, 1008}}, 32, 32, 128, 128, 16, 6, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{6, 20}}, 2, 2, 128, 128, 16, 8, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{254, 10}}, 32, 8, 128, 128, 16, 10, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching, GQA - paged_attention_test_params{ {{84, 2}}, 32, 32, 128, 128, 16, 16, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{1008, 492}}, 32, 32, 32, 32, 16, 32, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{1008, 492}}, 16, 16, 64, 64, 16, 64, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{1008, 492}}, 8, 8, 128, 128, 16, 128, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{2, 30}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{232, 24}}, 2, 2, 512, 512, 16, 32, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{1008, 592}}, 32, 32, 128, 128, 16, 64, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{1008, 692}}, 32, 32, 128, 128, 16, 128, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{1008, 792}}, 32, 32, 128, 128, 16, 256, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching - paged_attention_test_params{ {{1, 34}, {2, 20}, {10, 34}}, 2, 2, 64, 64, 16, 10, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 6, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 6, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{512, 0}}, 2, 2, 64, 32, 16, {100.0}, 20, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{512, 0}}, 2, 2, 64, 32, 16, {100.0}, 20, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, {100.0}, 8, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 64, 16, {100.0}, 128, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, {100.0}, 4, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{5, 10}}, 2, 2, 64, 64, 16, {100.0}, 2, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{8, 8}}, 16, 16, 256, 256, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{34, 0}}, 2, 2, 32, 32, 16, {100.0}, 2, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1, 1008}}, 32, 32, 128, 128, 16, {100.0}, 6, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{6, 20}}, 2, 2, 128, 128, 16, {100.0}, 8, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{254, 10}}, 32, 8, 128, 128, 16, {100.0}, 10, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching, GQA + paged_attention_test_params{ {{84, 2}}, 32, 32, 128, 128, 16, {100.0}, 16, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{1008, 492}}, 32, 32, 32, 32, 16, {100.0}, 32, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{1008, 492}}, 16, 16, 64, 64, 16, {100.0}, 64, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{1008, 492}}, 8, 8, 128, 128, 16, {100.0}, 128, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{2, 30}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{232, 24}}, 2, 2, 512, 512, 16, {100.0}, 32, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{1008, 592}}, 32, 32, 128, 128, 16, {100.0}, 64, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{1008, 692}}, 32, 32, 128, 128, 16, {100.0}, 128, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{1008, 792}}, 32, 32, 128, 128, 16, {100.0}, 256, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching + paged_attention_test_params{ {{1, 34}, {2, 20}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 10, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token +})); + +INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, xattention_test, ::testing::ValuesIn(std::vector{ + /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, + token_size>=32, disable_mix_mode */ + paged_attention_test_params{ {{32, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{2048, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + + paged_attention_test_params{ {{1, 31}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 32}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 1023}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 1024}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 127}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 128}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 129}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 32}}, 28, 28, 128, 128, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token })); -#endif diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp deleted file mode 100644 index 2eae9c3908cd20..00000000000000 --- a/src/plugins/intel_gpu/tests/unit/test_cases/xattention_gpu_test.cpp +++ /dev/null @@ -1,1273 +0,0 @@ -// Copyright (C) 2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "openvino/runtime/tensor.hpp" -#include "primitive_inst.h" -#include "random_generator.hpp" -#include "test_utils.h" - -using namespace cldnn; -using namespace ov::intel_gpu; -using namespace ::tests; - -enum class XAttentionScoresMode { DISABLED = 0, LAST_TOKEN, SNAPKV }; - -struct XAttentionSubsequenceDescriptor { - int num_tokens; - int past_len; -}; - -struct XAttentionCacheRotationDescriptor { - bool apply_rotation; - // configures 2nd dimension of rotation_deltas - // if per_block is true, single value is used for all tokens inside the block - // otherwise, each token uses an independent value - bool per_block; -}; - -struct XAttentionManager { - int num_heads; - int num_kv_heads; - int k_head_size; - int v_head_size; - int block_size; - int sliding_window_size; - bool kv_cache_compression; - ov::internal::CacheQuantMode key_cache_quant_mode; - bool has_score_aggregation; - XAttentionCacheRotationDescriptor rotation_config; - std::vector subsequence_descs; - - // per-subsequence QKV inputs - std::vector> query_data; // {[1, num_tokens, num_heads, k_head_size], ..} - std::vector> key_data; // {[1, past_len + num_tokens, num_heads, k_head_size], ..} - std::vector> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..} - - // common PA inputs - std::vector past_lens; - std::vector subsequence_begins; - std::vector block_indices; - std::vector block_indices_begins; - std::vector max_context_len; - std::vector score_aggregation_window; - - // score aggregation related inputs - std::vector score_aggregation; - - // rotation related inputs - std::vector rotated_block_indices; - std::vector rotation_deltas; - std::vector rotation_trig_lut; - - std::vector xattention_threshold; - std::vector xattention_block_size; - std::vector xattention_stride; - - std::vector sinks; - - cldnn::engine& test_engine; - cldnn::stream& test_stream; - tests::random_generator& rg; - - XAttentionManager(tests::random_generator& rg, - cldnn::engine& engine, - cldnn::stream& stream, - const std::vector& subsequence_descs, - int num_heads, - int num_kv_heads, - int k_head_size, - int v_head_size, - int block_size, - int sliding_window_size, - bool kv_cache_compression, - ov::internal::CacheQuantMode key_cache_quant_mode, - bool has_score_aggregation, - XAttentionCacheRotationDescriptor rotation_config, - std::vector threshold) - : num_heads(num_heads), - num_kv_heads(num_kv_heads), - k_head_size(k_head_size), - v_head_size(v_head_size), - block_size(block_size), - sliding_window_size(sliding_window_size), - kv_cache_compression(kv_cache_compression), - key_cache_quant_mode(key_cache_quant_mode), - has_score_aggregation(has_score_aggregation), - rotation_config(rotation_config), - subsequence_descs(subsequence_descs), - test_engine(engine), - test_stream(stream), - rg(rg) { - // init subsequence_begins and block_indices_begins - subsequence_begins.push_back(0); - block_indices_begins.push_back(0); - for (int i = 0; i < static_cast(threshold.size()); i++) { - xattention_threshold.emplace_back(static_cast(threshold[i])); - } - - int max_len = 0; - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - const auto& subsequence_desc = subsequence_descs[i]; - max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); - - query_data.push_back(generate_realistic_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); - key_data.push_back(generate_realistic_data(rg, num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); - value_data.push_back(generate_realistic_data(rg, num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); - - past_lens.push_back(subsequence_desc.past_len); - int subsequence_start_pos = subsequence_begins[i]; - int subsequence_end_pos = subsequence_start_pos + subsequence_desc.num_tokens; - subsequence_begins.push_back(subsequence_end_pos); - - int subsequence_length = subsequence_desc.num_tokens + subsequence_desc.past_len; - int required_blocks = ceil_div(subsequence_length, block_size); - int start_block_idx = block_indices.empty() ? 0 : block_indices.back() + 1; - int end_block_idx = start_block_idx + required_blocks; - for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { - block_indices.push_back(block_idx); - } - - int block_indices_start_pos = block_indices_begins[i]; - int block_indices_end_pos = block_indices_start_pos + required_blocks; - block_indices_begins.push_back(block_indices_end_pos); - } - max_context_len.push_back(max_len); - - if (rotation_config.apply_rotation) { - // iterate over KV-cache blocks and apply cache rotation to every second - // fully occupied block - for (size_t i = 0; i < subsequence_descs.size(); i++) { - const auto& subsequence_desc = subsequence_descs[i]; - int past_len = subsequence_desc.past_len; - int start_block_idx = block_indices_begins[i]; - for (int block_idx = 1; block_idx < past_len / block_size; block_idx++) { - if (block_idx % 2 != 0) { - rotated_block_indices.push_back(start_block_idx + block_idx); - } - } - } - - if (!rotated_block_indices.empty()) { - rotation_deltas = generate_rotation_deltas_data(rg, max_context_len[0], rotated_block_indices.size(), block_size, rotation_config.per_block); - rotation_trig_lut = generate_rotation_trig_lut_data(rg, max_context_len[0], k_head_size); - } - } - - if (has_score_aggregation) { - for (const auto& subsequence_desc : subsequence_descs) { - const auto max_tokens = 10; - auto max_window_size = std::min(subsequence_desc.num_tokens, max_tokens); - auto window_size = rg.generate_random_val(1, max_window_size); - score_aggregation.push_back(window_size); - } - } - } - - memory::ptr get_query_memory() { - return get_QKV_memory(query_data, num_heads, k_head_size, false); - } - - memory::ptr get_key_memory() { - return get_QKV_memory(key_data, num_kv_heads, k_head_size, true); - } - - memory::ptr get_value_memory() { - return get_QKV_memory(value_data, num_kv_heads, v_head_size, true); - } - - memory::ptr get_key_cache_memory() { - auto key_cache_dt = data_types::f16; - auto adjusted_head_size = k_head_size; - if (kv_cache_compression) { - key_cache_dt = data_types::i8; - adjusted_head_size += 4; - } - - auto num_blocks = block_indices.back() + 1; - auto key_cache_shape = ov::PartialShape{num_blocks, num_kv_heads, block_size, adjusted_head_size}; - auto key_cache_layout = layout{key_cache_shape, key_cache_dt, format::bfyx}; - auto memory = test_engine.allocate_memory(key_cache_layout); - - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - int past_len = subsequence_descs[i].past_len; - if (past_len != 0) { - int blocks_num = ceil_div(past_len + 1, block_size); - int start_block_idx = block_indices[block_indices_begins[i]]; - for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_size * block_idx) : block_size; - for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) { - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = key_data[i].data() + input_token_offset * num_kv_heads * v_head_size + head_idx * v_head_size; - if (kv_cache_compression) { - auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); - auto quantized_data_ptr = quantized_data.data(); - - // shape: [num_blocks, num_kv_heads, block_size, adjusted_head_size] - size_t output_block_offset = - (start_block_idx + block_idx) * num_kv_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; - size_t output_offset = output_block_offset + token_idx * v_head_size; - set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); - - size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; - set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); - set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); - } else { - // shape: [num_blocks, num_kv_heads, block_size, v_head_size] - size_t output_offset = (start_block_idx + block_idx) * num_kv_heads * block_size * v_head_size + - head_idx * block_size * v_head_size + token_idx * v_head_size; - - set_values(test_stream, memory, data_ptr, v_head_size, output_offset); - } - } - } - } - } - } - - return memory; - } - - memory::ptr get_value_cache_memory() { - auto value_cache_dt = data_types::f16; - auto adjusted_head_size = v_head_size; - if (kv_cache_compression) { - value_cache_dt = data_types::i8; - adjusted_head_size += 4; - } - - auto num_blocks = block_indices.back() + 1; - auto value_cache_shape = ov::PartialShape{num_blocks, num_kv_heads, block_size, adjusted_head_size}; - auto value_cache_layout = layout{value_cache_shape, value_cache_dt, format::bfyx}; - auto memory = test_engine.allocate_memory(value_cache_layout); - - for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { - int past_len = subsequence_descs[i].past_len; - if (past_len != 0) { - int blocks_num = ceil_div(past_len + 1, block_size); - int start_block_idx = block_indices[block_indices_begins[i]]; - for (int block_idx = 0; block_idx < blocks_num; block_idx++) { - int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_size * block_idx) : block_size; - for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { - for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) { - size_t input_token_offset = block_idx * block_size + token_idx; - ov::float16* data_ptr = value_data[i].data() + input_token_offset * num_kv_heads * v_head_size + head_idx * v_head_size; - if (kv_cache_compression) { - auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); - auto quantized_data_ptr = quantized_data.data(); - - // shape: [num_blocks, num_kv_heads, block_size, adjusted_head_size] - size_t output_block_offset = - (start_block_idx + block_idx) * num_kv_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; - size_t output_offset = output_block_offset + token_idx * v_head_size; - set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); - - size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; - set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); - set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); - } else { - // shape: [num_blocks, num_kv_heads, block_size, v_head_size] - size_t output_offset = (start_block_idx + block_idx) * num_kv_heads * block_size * v_head_size + - head_idx * block_size * v_head_size + token_idx * v_head_size; - - set_values(test_stream, memory, data_ptr, v_head_size, output_offset); - } - } - } - } - } - } - - return memory; - } - - memory::ptr get_past_lens_memory() { - return get_memory_from_vec(past_lens); - } - - memory::ptr get_subsequence_begins_memory() { - return get_memory_from_vec(subsequence_begins); - } - - memory::ptr get_block_indices_memory() { - return get_memory_from_vec(block_indices); - } - - memory::ptr get_block_indices_begins_memory() { - return get_memory_from_vec(block_indices_begins); - } - - memory::ptr get_scale_memory() { - std::vector scale = {ov::float16(get_default_scale())}; - return get_memory_from_vec(scale); - } - - memory::ptr get_sliding_window_memory() { - std::vector sliding_window = {0}; - return get_memory_from_vec(sliding_window); - } - - memory::ptr get_alibi_memory() { - std::vector alibi; - return get_memory_from_vec(alibi); - } - - memory::ptr get_max_context_len_memory() { - return get_memory_from_vec(max_context_len); - } - - memory::ptr get_score_aggregation() { - return get_memory_from_vec(score_aggregation); - } - - memory::ptr get_rotated_block_indices_memory() { - return get_memory_from_vec(rotated_block_indices); - } - - memory::ptr get_rotation_deltas_memory() { - auto mem = get_memory_from_vec(rotation_deltas); - auto layout = mem->get_layout(); - auto last_dim = rotation_config.per_block ? 1 : block_size; - layout.set_partial_shape(ov::PartialShape{static_cast(rotated_block_indices.size()), last_dim}); - - return test_engine.reinterpret_buffer(*mem, layout); - } - - memory::ptr get_rotation_trig_lut_memory() { - auto mem = get_memory_from_vec(rotation_trig_lut); - auto layout = mem->get_layout(); - layout.set_partial_shape(ov::PartialShape{max_context_len[0], k_head_size}); - - if (rotated_block_indices.empty()) { - auto empty_layout = mem->get_layout(); - empty_layout.set_partial_shape(ov::PartialShape{0, k_head_size}); - return test_engine.reinterpret_buffer(*mem, empty_layout); - } - - return test_engine.reinterpret_buffer(*mem, layout); - } - - memory::ptr get_xattention_threshold_memory() { - auto mem = get_memory_from_vec(xattention_threshold); - auto layout = mem->get_layout(); - layout.set_partial_shape(ov::PartialShape{1}); - if (xattention_threshold.empty()) { - auto empty_layout = mem->get_layout(); - empty_layout.set_partial_shape(ov::PartialShape{0}); - return test_engine.reinterpret_buffer(*mem, empty_layout); - } - - return test_engine.reinterpret_buffer(*mem, layout); - } - - memory::ptr get_xattention_block_size_memory() { - return get_memory_from_vec(xattention_block_size); - } - - memory::ptr get_xattention_stride_memory() { - return get_memory_from_vec(xattention_stride); - } - - memory::ptr get_sinks_memory() { - auto mem = get_memory_from_vec(sinks); - auto layout = mem->get_layout(); - layout.set_partial_shape(ov::PartialShape{1, num_heads, 1, 1}); - - if (sinks.empty()) { - auto empty_layout = mem->get_layout(); - empty_layout.set_partial_shape(ov::PartialShape{0, 0, 0, 0}); - return test_engine.reinterpret_buffer(*mem, empty_layout); - } - - return test_engine.reinterpret_buffer(*mem, layout); - } - - float get_default_scale() { - return static_cast(1.f / std::sqrt(k_head_size)); - } - -private: - template - memory::ptr get_memory_from_vec(std::vector& input_data) { - auto data_size = input_data.empty() ? 1 : input_data.size(); - auto shape = ov::PartialShape{static_cast(data_size)}; - auto layout = cldnn::layout{shape, ov::element::from(), format::bfyx}; - auto memory = test_engine.allocate_memory(layout); - - if (input_data.empty()) { - auto shape = ov::PartialShape{0}; - auto layout = cldnn::layout{shape, ov::element::from(), format::bfyx}; - return test_engine.reinterpret_buffer(*memory, layout); - } - - set_values(test_stream, memory, input_data.data(), input_data.size(), 0); - - return memory; - } - - memory::ptr get_QKV_memory(std::vector>& input_data, int num_heads, int head_size, bool skip_past_len) { - int total_tokens = 0; - for (const auto& subsequence_desc : subsequence_descs) - total_tokens += subsequence_desc.num_tokens; - - auto query_shape = ov::PartialShape{total_tokens, num_heads * head_size}; - auto query_layout = layout{query_shape, data_types::f16, format::bfyx}; - auto memory = test_engine.allocate_memory(query_layout); - - for (int subsequence_idx = 0; subsequence_idx < static_cast(subsequence_descs.size()); subsequence_idx++) { - for (int token_idx = 0; token_idx < subsequence_descs[subsequence_idx].num_tokens; token_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - size_t input_token_offset = token_idx; - // as generated data stored in vectors includes past_len, ignore it for KV inputs - if (skip_past_len) - input_token_offset += subsequence_descs[subsequence_idx].past_len; - - ov::float16* data_ptr = input_data[subsequence_idx].data() + input_token_offset * num_heads * head_size + head_idx * head_size; - - size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx; - size_t output_offset = output_token_offset * num_heads * head_size + head_idx * head_size; - - set_values(test_stream, memory, data_ptr, head_size, output_offset); - } - } - } - - return memory; - } - - template - static void set_values(stream& stream, memory::ptr mem, T* vals, size_t size, size_t dst_offset) { - mem_lock mem_ptr(mem, stream); - for (size_t i = 0; i < size; i++) { - mem_ptr[dst_offset + i] = vals[i]; - } - } - - static std::vector generate_input_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t k_head_size) { - const size_t total_elements_num = tokens_num * num_heads * k_head_size; - auto data = rg.generate_random_1d(total_elements_num, -1, 1); - - return data; - } - - static std::vector generate_realistic_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t k_head_size) { - std::vector data(num_heads * tokens_num * k_head_size); - - std::mt19937 gen(1234); - std::normal_distribution dist(0.0f, 0.1f); - - for (size_t h = 0; h < num_heads; ++h) { - for (size_t t = 0; t < tokens_num; ++t) { - for (size_t d = 0; d < k_head_size; ++d) { - float val = dist(gen); - if (t > 0) - val = 0.8f * val + 0.2f * static_cast(data[h * tokens_num * k_head_size + (t - 1) * k_head_size + d]); - data[h * tokens_num * k_head_size + t * k_head_size + d] = static_cast(val); - } - } - } - - return data; - } - - static std::vector generate_rotation_deltas_data(tests::random_generator& rg, - size_t max_tokens_num, - size_t rotated_blocks_num, - size_t block_size, - bool per_block) { - const size_t total_elements_num = per_block ? rotated_blocks_num : rotated_blocks_num * block_size; - auto data = rg.generate_random_1d(total_elements_num, 0, static_cast(max_tokens_num - 1)); - - return data; - } - - static std::vector generate_rotation_trig_lut_data(tests::random_generator& rg, size_t max_tokens_num, size_t k_head_size) { - const size_t total_elements_num = max_tokens_num * k_head_size; - auto data = rg.generate_random_1d(total_elements_num, -1, 1); - - return data; - } - - static std::tuple, ov::float16, ov::float16> quantize_data(ov::float16* data, size_t size, bool expand_range = false) { - float min_value = std::numeric_limits::max(); - float max_value = std::numeric_limits::lowest(); - - for (size_t i = 0; i < size; i++) { - min_value = std::min((float)(data[i]), min_value); - max_value = std::max((float)(data[i]), max_value); - } - - float diff_value = 0.001; - if (max_value != min_value) - diff_value = max_value - min_value; - if (expand_range && std::abs(diff_value) <= std::abs(max_value) * 0.1f) { - // compensate too small range - diff_value = (max_value - min_value) + std::max(1.0f, max_value * 0.1f); - } - float scale = (std::numeric_limits::max() - std::numeric_limits::lowest()) / diff_value; - float zp = ((float)-min_value * scale) + std::numeric_limits::lowest(); - - std::vector quantized_data; - quantized_data.resize(size); - - auto convert_char_rte = [](float val) { - float rounded = std::nearbyint(val); - - if (rounded > 127.0f) { - return static_cast(127); - } else if (rounded < -128.0f) { - return static_cast(-128); - } else { - return static_cast(rounded); - } - }; - - for (size_t i = 0; i < size; i++) { - quantized_data[i] = convert_char_rte(data[i] * scale + zp); - } - - scale = 1.0f / scale; - - return std::make_tuple(quantized_data, scale, zp); - } -}; - -struct xAttentionReference { - xAttentionReference(XAttentionManager& xam) : xam(xam), test_engine(xam.test_engine), test_stream(xam.test_stream) {} - - std::pair, std::vector> get_reference(std::vector threshold) { - std::vector ref_data_output; - std::vector ref_scores_output; - - for (size_t i = 0; i < xam.subsequence_descs.size(); i++) { - const auto& subsequence_desc = xam.subsequence_descs[i]; - const auto kv_seq_len = subsequence_desc.num_tokens + subsequence_desc.past_len; - - auto key_data = xam.key_data[i]; - if (xam.rotation_config.apply_rotation) { - auto blocks_start = xam.block_indices_begins[i]; - auto blocks_end = xam.block_indices_begins[i + 1]; - - std::vector block_indices(xam.block_indices.begin() + blocks_start, xam.block_indices.begin() + blocks_end); - - for (const auto& block_idx : block_indices) { - auto it = std::find(xam.rotated_block_indices.begin(), xam.rotated_block_indices.end(), block_idx); - if (it != xam.rotated_block_indices.end()) { - int index = std::distance(xam.rotated_block_indices.begin(), it); - int subsequence_rotated_block_idx = *it - blocks_start; - - rotate_block(key_data, - xam.rotation_deltas, - xam.rotation_trig_lut, - index, - subsequence_rotated_block_idx, - xam.num_kv_heads, - xam.k_head_size, - xam.block_size, - xam.rotation_config.per_block); - } - } - } - - auto window_size = xam.has_score_aggregation ? xam.score_aggregation[i] : 1; - - auto subsequence_ref_results = run_reference(xam.query_data[i], - key_data, - xam.value_data[i], - subsequence_desc.num_tokens, - kv_seq_len, - xam.num_heads, - xam.num_kv_heads, - xam.k_head_size, - xam.v_head_size, - window_size, - xam.sliding_window_size, - xam.get_default_scale(), - static_cast(threshold[i])); - - // concatenate all subsequences into one vector - ref_data_output.insert(ref_data_output.end(), subsequence_ref_results.first.begin(), subsequence_ref_results.first.end()); - ref_scores_output.insert(ref_scores_output.end(), subsequence_ref_results.second.begin(), subsequence_ref_results.second.end()); - } - - return {ref_data_output, ref_scores_output}; - } - -private: - std::pair, std::vector> run_reference(const std::vector& query_data, - const std::vector& key_data, - const std::vector& value_data, - int num_queries, - int num_keys, - int num_heads, - int num_kv_heads, - int k_head_size, - int v_head_size, - int window_size, - int sliding_window_size, - float scale, - double threshold = 0.9, - size_t block_size = 128, - size_t stride = 16) { - auto query_shape = ov::PartialShape{1, num_queries, num_heads, k_head_size}; - auto key_shape = ov::PartialShape{1, num_keys, num_kv_heads, k_head_size}; - auto value_shape = ov::PartialShape{1, num_keys, num_kv_heads, v_head_size}; - if (num_heads != num_kv_heads) { - query_shape = ov::PartialShape{num_queries, num_kv_heads, (num_heads / num_kv_heads), k_head_size}; - key_shape = ov::PartialShape{num_keys, num_kv_heads, 1, k_head_size}; - value_shape = ov::PartialShape{num_keys, num_kv_heads, 1, v_head_size}; - } - auto query_layout = layout{query_shape, data_types::f16, format::bfyx}; - auto key_layout = layout{key_shape, data_types::f16, format::bfyx}; - auto value_layout = layout{value_shape, data_types::f16, format::bfyx}; - auto scale_layout = cldnn::layout({1}, data_types::f16, format::bfyx); - - OPENVINO_ASSERT(query_layout.count() == query_data.size()); - OPENVINO_ASSERT(key_layout.count() == key_data.size()); - OPENVINO_ASSERT(value_layout.count() == value_data.size()); - - auto query_mem = test_engine.allocate_memory(query_layout); - auto key_mem = test_engine.allocate_memory(key_layout); - auto value_mem = test_engine.allocate_memory(value_layout); - auto scale_mem = test_engine.allocate_memory(scale_layout); - - set_values(query_mem, query_data); - set_values(key_mem, key_data); - set_values(value_mem, value_data); - set_values(scale_mem, {static_cast(scale)}); - - auto reorder_qhk_to_hqd = [&](const std::vector& src, int outer_len, int num_heads, int head_dim) { - std::vector dst(num_heads * outer_len * head_dim); - for (int h = 0; h < num_heads; ++h) { - size_t dst_h_off = static_cast(h) * outer_len * head_dim; - for (int i = 0; i < outer_len; ++i) { - size_t src_off = static_cast(i) * num_heads * head_dim + static_cast(h) * head_dim; - std::copy_n(&src[src_off], head_dim, &dst[dst_h_off + static_cast(i) * head_dim]); - } - } - return dst; - }; - - const auto query_data_3d = reorder_qhk_to_hqd(query_data, num_queries, num_heads, k_head_size); - const auto key_data_3d = reorder_qhk_to_hqd(key_data, num_keys, num_heads, k_head_size); - - ov::reference::XAttentionRetainedBlockIndicesForAllHeads retained_blocks; - if (num_queries >= static_cast(block_size)) { - const size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; - const size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; - - std::vector query_padded(num_heads * padded_q * k_head_size, 0.f); - std::vector key_padded(num_heads * padded_k * k_head_size, 0.f); - - for (int h = 0; h < num_heads; ++h) { - const auto* q_src = &query_data_3d[h * num_queries * k_head_size]; - const auto* k_src = &key_data_3d[h * num_keys * k_head_size]; - auto* q_dst = &query_padded[h * padded_q * k_head_size]; - auto* k_dst = &key_padded[h * padded_k * k_head_size]; - - std::transform(q_src, q_src + num_queries * k_head_size, q_dst, [](ov::float16 v) { - return static_cast(v); - }); - std::transform(k_src, k_src + num_keys * k_head_size, k_dst, [](ov::float16 v) { - return static_cast(v); - }); - } - ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); - retained_blocks = selector.select_blocks(query_padded.data(), - {static_cast(num_heads), padded_q, static_cast(k_head_size)}, - key_padded.data(), - {static_cast(num_heads), padded_k, static_cast(k_head_size)}); - } - auto mask_mem = get_mask_mem_combined_multi_head(num_queries, - num_keys, - num_heads, - num_kv_heads, - sliding_window_size, - retained_blocks, - static_cast(block_size)); - topology topology; - if (num_heads == num_kv_heads) { - topology.add(input_layout("query", query_layout), - input_layout("key", key_layout), - input_layout("value", value_layout), - data("mask", mask_mem), - data("scale", scale_mem), - permute("query_transposed", input_info("query"), {0, 2, 1, 3}), - permute("key_transposed", input_info("key"), {0, 2, 3, 1}), - permute("value_transposed", input_info("value"), {0, 2, 1, 3}), - gemm("qk_gemm", {input_info("query_transposed"), input_info("key_transposed")}, data_types::f16, false, false), - eltwise("scale_div", {input_info("qk_gemm"), input_info("scale")}, eltwise_mode::prod), - eltwise("eltwise", {input_info("scale_div"), input_info("mask")}, eltwise_mode::sum), - softmax("softmax", input_info("eltwise"), -1), - gemm("qkv_gemm", {input_info("softmax"), input_info("value_transposed")}, data_types::f16, false, false), - permute("qkv_gemm_transposed", input_info("qkv_gemm"), {0, 2, 1, 3}), - reorder("output_data", input_info("qkv_gemm_transposed"), format::bfyx, data_types::f16), - reorder("scores_data", input_info("softmax"), format::bfyx, data_types::f16)); - } else { - topology.add(input_layout("query", query_layout), - input_layout("key", key_layout), - input_layout("value", value_layout), - data("mask", mask_mem), - data("scale", scale_mem), - permute("query_transposed", input_info("query"), {1, 2, 0, 3}), - permute("key_transposed", input_info("key"), {1, 2, 3, 0}), - permute("value_transposed", input_info("value"), {1, 2, 0, 3}), - gemm("qk_gemm", {input_info("query_transposed"), input_info("key_transposed")}, data_types::f16, false, false), - eltwise("scale_div", {input_info("qk_gemm"), input_info("scale")}, eltwise_mode::prod), - eltwise("eltwise", {input_info("scale_div"), input_info("mask")}, eltwise_mode::sum), - softmax("softmax", input_info("eltwise"), -1), - gemm("qkv_gemm", {input_info("softmax"), input_info("value_transposed")}, data_types::f16, false, false), - reshape("qkv_gemm_reshape", input_info("qkv_gemm"), {1, num_heads, v_head_size, num_queries}), - permute("qkv_gemm_transposed", input_info("qkv_gemm_reshape"), {0, 2, 1, 3}), - reorder("output_data", input_info("qkv_gemm_transposed"), format::bfyx, data_types::f16), - reorder("scores_data", input_info("softmax"), format::bfyx, data_types::f16)); - } - - ExecutionConfig config = get_test_default_config(test_engine); - config.set_property(ov::intel_gpu::optimize_data(true)); - config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); - - network::ptr network = get_network(test_engine, topology, config, get_test_stream_ptr(), false); - network->set_input_data("query", query_mem); - network->set_input_data("key", key_mem); - network->set_input_data("value", value_mem); - - auto outputs = network->execute(); - - auto output_data_mem = outputs.at("output_data").get_memory(); - auto output_scores_mem = outputs.at("scores_data").get_memory(); - - return {get_output_data_vec(output_data_mem, num_queries, v_head_size, num_heads), - get_output_scores_vec(output_scores_mem, window_size, num_queries, num_keys, num_heads)}; - } - - std::vector get_output_scores_vec(memory::ptr scores_output, int window_size, int num_queries, int num_keys, int num_heads) { - OPENVINO_ASSERT(scores_output->count() == static_cast(num_heads * num_queries * num_keys)); - - std::vector output_scores(num_keys, 0); - mem_lock mem_ptr(scores_output, test_stream); - for (int row_idx = 0; row_idx < window_size; row_idx++) { - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - for (int score_idx = 0; score_idx < num_keys; score_idx++) { - auto scores_offset = head_idx * num_queries * num_keys + (num_queries - window_size + row_idx) * num_keys + score_idx; - output_scores[score_idx] += mem_ptr[scores_offset]; - } - } - } - - return output_scores; - } - - std::vector get_output_data_vec(memory::ptr data_output, int num_queries, int k_head_size, int num_heads) { - OPENVINO_ASSERT(data_output->count() == static_cast(num_queries * num_heads * k_head_size)); - - std::vector output_data(data_output->count()); - mem_lock mem_ptr(data_output, test_stream); - for (size_t i = 0; i < data_output->count(); i++) - output_data[i] = mem_ptr[i]; - - return output_data; - } - - memory::ptr get_mask_mem_combined_multi_head(int num_queries, - int num_keys, - int num_heads, - int num_kv_heads, - int sliding_window_size, - const ov::reference::XAttentionRetainedBlockIndicesForAllHeads& retained_blocks, - int block_size) { - OPENVINO_ASSERT(num_kv_heads > 0, "num_kv_heads must be > 0"); - OPENVINO_ASSERT(num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"); - - int heads_per_kv = num_heads / num_kv_heads; - - ov::PartialShape mask_shape; - if (num_heads == num_kv_heads) { - mask_shape = ov::PartialShape{1, num_heads, num_queries, num_keys}; - } else { - mask_shape = ov::PartialShape{num_kv_heads, heads_per_kv, num_queries, num_keys}; - } - - auto mask_layout = layout{mask_shape, data_types::f16, format::bfyx}; - auto mask_mem = test_engine.allocate_memory(mask_layout); - mem_lock mem_ptr(mask_mem, test_stream); - - size_t total_elems = mask_layout.count(); - for (size_t i = 0; i < total_elems; ++i) - mem_ptr[i] = std::numeric_limits::lowest(); - - for (int h = 0; h < num_heads; ++h) { - int kv_idx = (num_heads == num_kv_heads) ? 0 : (h / heads_per_kv); - int head_in_kv = (num_heads == num_kv_heads) ? h : (h % heads_per_kv); - - size_t head_offset = (static_cast(kv_idx) * heads_per_kv + static_cast(head_in_kv)) * static_cast(num_queries) * - static_cast(num_keys); - - if (retained_blocks.empty() || retained_blocks[h].empty()) { - for (int i = 0; i < num_queries; i++) { - for (int j = 0; j < num_keys; j++) { - ov::float16 value = ov::float16(0.f); - if (sliding_window_size == 0) { - int past_len = num_keys - num_queries + 1; - if (j >= past_len + i) - value = std::numeric_limits::lowest(); - } else { - int sliding_left = num_keys - num_queries - sliding_window_size + 1; - int past_len = num_keys - num_queries + 1; - bool is_min; - if (num_queries == num_keys) { - is_min = (j >= sliding_left + i) && (j <= i) ? 0 : 1; - } else { - is_min = (j >= sliding_left + i) && (j < past_len + i) ? 0 : 1; - } - if (is_min) - value = std::numeric_limits::lowest(); - } - mem_ptr[head_offset + i * num_keys + j] = value; - } - } - continue; - } - - for (int i = 0; i < num_queries; i++) { - int left_idx = 0; - int right_idx = 0; - - if (sliding_window_size == 0) { - int past_len = num_keys - num_queries + 1; - right_idx = past_len + i - 1; - left_idx = 0; - } else { - int sliding_left = num_keys - num_queries - sliding_window_size + 1; - int past_len = num_keys - num_queries + 1; - if (num_queries == num_keys) { - left_idx = sliding_left + i; - right_idx = i; - } else { - left_idx = sliding_left + i; - right_idx = past_len + i - 1; - } - } - - left_idx = std::max(0, left_idx); - right_idx = std::min(num_keys - 1, right_idx); - - for (const auto& [q_block_idx, k_block_idx] : retained_blocks[h]) { - int q_start = q_block_idx * block_size; - int q_end = std::min(q_start + block_size, num_queries); - int k_start = k_block_idx * block_size; - int k_end = std::min(k_start + block_size, num_keys); - - if (i < q_start || i >= q_end) - continue; - - for (int j = k_start; j < k_end; j++) { - if (j >= left_idx && j <= right_idx) { - mem_ptr[head_offset + i * num_keys + j] = ov::float16(0.f); - } - } - } - } - } - - return mask_mem; - } - - void rotate_block(std::vector& cache_data, - std::vector rotation_deltas, - std::vector rotation_trig_lut_mem, - int rotated_block_idx, - int subsequence_rotated_block_idx, - int num_heads, - int k_head_size, - int block_size, - bool per_block) { - // cache_data shape: [1, num_tokens, num_heads, k_head_size] - int start_token_idx = subsequence_rotated_block_idx * block_size; - - for (int token_idx = 0; token_idx < block_size; token_idx++) { - auto rotation_deltas_offset = per_block ? rotated_block_idx : rotated_block_idx * block_size + token_idx; - auto rotation_trig_lut_idx = rotation_deltas[rotation_deltas_offset]; - for (int head_idx = 0; head_idx < num_heads; head_idx++) { - for (int k_head_size_idx = 0; k_head_size_idx < k_head_size / 2; k_head_size_idx++) { - auto input_offset = (start_token_idx + token_idx) * num_heads * k_head_size + head_idx * k_head_size + k_head_size_idx; - - auto cache_value_0 = cache_data[input_offset]; - auto cache_value_1 = cache_data[input_offset + k_head_size / 2]; - - ov::float16 rotation_value_cos = rotation_trig_lut_mem[rotation_trig_lut_idx * k_head_size + k_head_size_idx]; - ov::float16 rotation_value_sin = rotation_trig_lut_mem[rotation_trig_lut_idx * k_head_size + k_head_size_idx + k_head_size / 2]; - - cache_data[input_offset] = cache_value_0 * rotation_value_cos - cache_value_1 * rotation_value_sin; - cache_data[input_offset + k_head_size / 2] = cache_value_0 * rotation_value_sin + cache_value_1 * rotation_value_cos; - } - } - } - } - - XAttentionManager& xam; - cldnn::engine& test_engine; - cldnn::stream& test_stream; -}; - -template -struct xAttentionTest : public ::testing::TestWithParam { -public: - random_generator rg; - cldnn::engine& engine = get_test_engine(); - float tolerance = 2e-3; - - void SetUp() override { - rg.set_seed(GET_SUITE_NAME); - } - - void execute(T& p) { - XAttentionManager xam(rg, - get_test_engine(), - get_test_stream(), - p.subsequences, - p.num_heads, - p.num_kv_heads, - p.k_head_size, - p.v_head_size, - p.block_size, - p.sliding_window_size, - p.kv_cache_compression, - p.key_cache_quant_mode, - p.scores_mode == XAttentionScoresMode::SNAPKV, - p.rotation_config, - p.threshold); - - if (p.kv_cache_compression) - tolerance = 25e-3; - - auto query_mem = xam.get_query_memory(); - auto key_mem = xam.get_key_memory(); - auto value_mem = xam.get_value_memory(); - - auto key_cache_mem = xam.get_key_cache_memory(); - auto value_cache_mem = xam.get_value_cache_memory(); - - auto past_lens_mem = xam.get_past_lens_memory(); - auto subsequence_begins_mem = xam.get_subsequence_begins_memory(); - auto block_indices_mem = xam.get_block_indices_memory(); - auto block_indices_begins_mem = xam.get_block_indices_begins_memory(); - - auto scale_mem = xam.get_scale_memory(); - auto sliding_window_mem = xam.get_sliding_window_memory(); - auto alibi_mem = xam.get_alibi_memory(); - auto max_context_len_mem = xam.get_max_context_len_memory(); - - // scores calculation related memory buffers - auto score_aggregation_mem = xam.get_score_aggregation(); - - // cache rotation related memory buffers - auto rotated_block_indices_mem = xam.get_rotated_block_indices_memory(); - auto rotation_deltas_mem = xam.get_rotation_deltas_memory(); - auto rotation_trig_lut_mem = xam.get_rotation_trig_lut_memory(); - - auto xattention_threshold_mem = xam.get_xattention_threshold_memory(); - auto xattention_block_size_mem = xam.get_xattention_block_size_memory(); - auto xattention_stride_mem = xam.get_xattention_stride_memory(); - auto sinks_mem = xam.get_sinks_memory(); - - auto query_layout = query_mem->get_layout(); - auto key_layout = key_mem->get_layout(); - auto value_layout = value_mem->get_layout(); - auto key_cache_layout = key_cache_mem->get_layout(); - auto value_cache_layout = value_cache_mem->get_layout(); - auto past_lens_layout = past_lens_mem->get_layout(); - auto subsequence_begins_layout = subsequence_begins_mem->get_layout(); - auto block_indices_layout = block_indices_mem->get_layout(); - auto block_indices_begins_layout = block_indices_begins_mem->get_layout(); - auto scale_layout = scale_mem->get_layout(); - auto sliding_window_layout = sliding_window_mem->get_layout(); - auto alibi_layout = alibi_mem->get_layout(); - auto max_context_len_layout = max_context_len_mem->get_layout(); - auto score_aggregation_window_layout = score_aggregation_mem->get_layout(); - auto rotated_block_indices_layout = rotated_block_indices_mem->get_layout(); - auto rotation_deltas_layout = rotation_deltas_mem->get_layout(); - auto rotation_trig_lut_layout = rotation_trig_lut_mem->get_layout(); - auto xattention_threshold_layout = xattention_threshold_mem->get_layout(); - auto xattention_block_size_layout = xattention_block_size_mem->get_layout(); - auto xattention_stride_layout = xattention_stride_mem->get_layout(); - auto sinks_layout = sinks_mem->get_layout(); - - // make layouts dynamic - query_layout.set_partial_shape(ov::PartialShape{-1, p.num_heads * p.k_head_size}); - key_layout.set_partial_shape(ov::PartialShape{-1, p.num_kv_heads * p.k_head_size}); - value_layout.set_partial_shape(ov::PartialShape{-1, p.num_kv_heads * p.v_head_size}); - { - auto pshape = key_cache_layout.get_partial_shape(); - pshape[0] = -1; - key_cache_layout.set_partial_shape(pshape); - } - { - auto pshape = value_cache_layout.get_partial_shape(); - pshape[0] = -1; - value_cache_layout.set_partial_shape(pshape); - } - past_lens_layout.set_partial_shape(ov::PartialShape{-1}); - subsequence_begins_layout.set_partial_shape(ov::PartialShape{-1}); - block_indices_layout.set_partial_shape(ov::PartialShape{-1}); - block_indices_begins_layout.set_partial_shape(ov::PartialShape{-1}); - score_aggregation_window_layout.set_partial_shape(ov::PartialShape{-1}); - rotated_block_indices_layout.set_partial_shape(ov::PartialShape{-1}); - rotation_deltas_layout.set_partial_shape(ov::PartialShape{-1, -1}); - rotation_trig_lut_layout.set_partial_shape(ov::PartialShape{-1, p.k_head_size}); - xattention_threshold_layout.set_partial_shape(ov::PartialShape{-1}); - - if (p.dynamic_paddings) { - const auto padding_axis = 1; - const auto pad_before = p.k_head_size; - const auto pad_after = p.k_head_size * 2; - - query_layout.data_padding._dynamic_dims_mask[padding_axis] = 1; - - auto query_data_layout = query_mem->get_layout(); - auto padded_query_data_layout = query_data_layout; - padded_query_data_layout.data_padding._lower_size[padding_axis] = pad_before; - padded_query_data_layout.data_padding._upper_size[padding_axis] = pad_after; - - auto new_query_memory = get_test_engine().allocate_memory(padded_query_data_layout, false); - - mem_lock query_mem_lock(query_mem, get_test_stream()); - mem_lock new_query_mem_lock(new_query_memory, get_test_stream()); - - auto query_data_shape = query_data_layout.get_shape(); - for (size_t b = 0; b < query_data_shape[0]; b++) { - for (size_t f = 0; f < query_data_shape[1]; f++) { - auto input_offset = query_data_layout.get_linear_offset(cldnn::tensor(static_cast(b), static_cast(f), 0, 0, 0, 0)); - auto output_offset = - padded_query_data_layout.get_linear_offset(cldnn::tensor(static_cast(b), static_cast(f), 0, 0, 0, 0)); - - new_query_mem_lock[output_offset] = query_mem_lock[input_offset]; - } - } - query_mem = new_query_memory; - } - - std::vector pa_inputs = { - input_info("query"), - input_info("key"), - input_info("value"), - input_info("key_cache"), - input_info("value_cache"), - input_info("past_lens"), - input_info("subsequence_begins"), - input_info("block_indices"), - input_info("block_indices_begins"), - input_info("scale"), - input_info("sliding_window"), - input_info("alibi"), - input_info("max_context_len"), - input_info("score_aggregation_window"), - input_info("rotated_block_indices"), - input_info("rotation_deltas"), - input_info("rotation_trig_lut_modified"), - input_info("xattention_threshold"), - input_info("xattention_block_size"), - input_info("xattention_stride"), - input_info("sinks"), - }; - - auto pa_prim = paged_attention("paged_attention", pa_inputs); - - pa_prim.k_head_size = p.k_head_size; - pa_prim.v_head_size = p.v_head_size; - pa_prim.kv_heads_num = p.num_kv_heads; - pa_prim.heads_num = p.num_heads; - pa_prim.scale_val = xam.get_default_scale(); - pa_prim.has_alibi = false; - pa_prim.num_outputs = p.scores_mode == XAttentionScoresMode::DISABLED ? 1 : 2; - pa_prim.has_rotated_blocks = p.rotation_config.apply_rotation; - pa_prim.has_score_aggregation = p.scores_mode == XAttentionScoresMode::SNAPKV; - pa_prim.sliding_window = p.sliding_window_size; - pa_prim.is_key_by_channel = (p.key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL); - pa_prim.has_xattention = true; - topology topology; - - topology.add(input_layout("query", query_layout), - input_layout("key", key_layout), - input_layout("value", value_layout), - input_layout("key_cache", key_cache_layout), - input_layout("value_cache", value_cache_layout), - input_layout("past_lens", past_lens_layout), - input_layout("subsequence_begins", subsequence_begins_layout), - input_layout("block_indices", block_indices_layout), - input_layout("block_indices_begins", block_indices_begins_layout), - input_layout("scale", scale_layout), - input_layout("sliding_window", sliding_window_layout), - input_layout("alibi", alibi_layout), - input_layout("max_context_len", max_context_len_layout), - input_layout("score_aggregation_window", score_aggregation_window_layout), - pa_prim, - reorder("output_data", input_info("paged_attention", 0), format::bfyx, data_types::f16)); - - if (p.scores_mode != XAttentionScoresMode::DISABLED) { - topology.add(reorder("output_scores", input_info("paged_attention", 1), format::bfyx, data_types::f16)); - } - - { - topology.add(input_layout("rotated_block_indices", rotated_block_indices_layout)); - topology.add(input_layout("rotation_deltas", rotation_deltas_layout)); - topology.add(input_layout("rotation_trig_lut", rotation_trig_lut_layout)); - - // add dummy activation operation to simulate an empty PA `rotation_trig_lut` buffer for shapes like [0, k_head_size] - topology.add(activation("rotation_trig_lut_modified", input_info("rotation_trig_lut"), activation_func::none)); - - topology.add(input_layout("xattention_threshold", xattention_threshold_layout)); - topology.add(input_layout("xattention_block_size", xattention_block_size_layout)); - topology.add(input_layout("xattention_stride", xattention_stride_layout)); - topology.add(input_layout("sinks", sinks_layout)); - } - - ExecutionConfig config = get_test_default_config(get_test_engine()); - config.set_property(ov::intel_gpu::optimize_data(true)); - config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); - // FlashAttn v1 or v2? - config.set_property(ov::intel_gpu::could_use_flashattn_v2(p.disable_flashattn_v2)); - config.set_property(ov::internal::key_cache_quant_mode(p.key_cache_quant_mode)); - network::ptr network = get_network(get_test_engine(), topology, config, get_test_stream_ptr(), false); - network->set_input_data("query", query_mem); - network->set_input_data("key", key_mem); - network->set_input_data("value", value_mem); - network->set_input_data("key_cache", key_cache_mem); - network->set_input_data("value_cache", value_cache_mem); - network->set_input_data("past_lens", past_lens_mem); - network->set_input_data("subsequence_begins", subsequence_begins_mem); - network->set_input_data("block_indices", block_indices_mem); - network->set_input_data("block_indices_begins", block_indices_begins_mem); - network->set_input_data("scale", scale_mem); - network->set_input_data("sliding_window", sliding_window_mem); - network->set_input_data("alibi", alibi_mem); - network->set_input_data("max_context_len", max_context_len_mem); - network->set_input_data("score_aggregation_window", score_aggregation_mem); - network->set_input_data("rotated_block_indices", rotated_block_indices_mem); - network->set_input_data("rotation_deltas", rotation_deltas_mem); - network->set_input_data("rotation_trig_lut", rotation_trig_lut_mem); - network->set_input_data("xattention_threshold", xattention_threshold_mem); - network->set_input_data("xattention_block_size", xattention_block_size_mem); - network->set_input_data("xattention_stride", xattention_stride_mem); - network->set_input_data("sinks", sinks_mem); - - auto outputs = network->execute(); - - cldnn::memory::ptr output_data_mem = nullptr; - cldnn::memory::ptr output_scores_mem = nullptr; - - output_data_mem = outputs.at("output_data").get_memory(); - if (p.scores_mode != XAttentionScoresMode::DISABLED) { - output_scores_mem = outputs.at("output_scores").get_memory(); - } - auto ref_data = xAttentionReference(xam).get_reference(p.threshold); - compare(output_data_mem, output_scores_mem, ref_data); - } - - void compare(memory::ptr data_output_mem, memory::ptr scores_output_mem, std::pair, std::vector> ref_data) { - if (data_output_mem) { - ASSERT_EQ(data_output_mem->count(), ref_data.first.size()); - mem_lock mem_ptr(data_output_mem, get_test_stream()); - int mismatch_count = 0; - for (size_t i = 0; i < data_output_mem->count(); i++) { - if (std::fabs(static_cast(mem_ptr[i]) - static_cast(ref_data.first[i])) > tolerance) { - mismatch_count++; - } - } - EXPECT_LE(mismatch_count, int(data_output_mem->count() * 0.04)); - } - - if (scores_output_mem) { - ASSERT_EQ(scores_output_mem->count(), ref_data.second.size()); - mem_lock mem_ptr(scores_output_mem, get_test_stream()); - int mismatch_count = 0; - for (size_t i = 0; i < scores_output_mem->count(); i++) { - if (std::fabs(static_cast(mem_ptr[i]) - static_cast(ref_data.second[i])) > tolerance) { - mismatch_count++; - } - } - EXPECT_LE(mismatch_count, int(scores_output_mem->count() * 0.04)); - } - } - - static bool check_xattention_available() { - auto& engine = get_test_engine(); - ExecutionConfig config = get_test_default_config(engine); - if (!cldnn::check_cm_jit_support(engine, config) || !engine.get_device_info().supports_immad) { - return false; - } - - return true; - } -}; - -struct xattention_test_params { - std::vector subsequences; - int num_heads; - int num_kv_heads; - int k_head_size; - int v_head_size; - int block_size; - std::vector threshold; - int sliding_window_size; - bool kv_cache_compression; - ov::internal::CacheQuantMode key_cache_quant_mode; - bool dynamic_paddings; - XAttentionScoresMode scores_mode; - XAttentionCacheRotationDescriptor rotation_config; - bool disable_flashattn_v2; -}; - -class xattention_test : public xAttentionTest {}; -TEST_P(xattention_test, basic) { - if (!check_xattention_available()) - GTEST_SKIP(); - auto p = GetParam(); - - execute(p); -} - -const auto ENABLE_CACHE_COMPRESSION = true; -const auto DISABLE_CACHE_COMPRESSION = false; -const auto DISABLE_SCORES = XAttentionScoresMode::DISABLED; -const auto ENABLE_SCORES = XAttentionScoresMode::LAST_TOKEN; -const auto ENABLE_SCORES_SNAPKV = XAttentionScoresMode::SNAPKV; -const auto PER_BLOCK_ROTATION = XAttentionCacheRotationDescriptor{true, true}; -const auto PER_TOKEN_ROTATION = XAttentionCacheRotationDescriptor{true, false}; -const auto DISABLE_ROTATION = XAttentionCacheRotationDescriptor{false, false}; -const auto STATIC_INPUT_PAD = false; -const auto DYNAMIC_INPUT_PAD = true; -const auto ENABLE_FA_V2 = false; -const auto DISABLE_FA_V2 = true; - -INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, - xattention_test, - ::testing::ValuesIn(std::vector{ - /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, - token_size>=32, disable_mix_mode */ - xattention_test_params{ {{32, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - xattention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - xattention_test_params{ {{2048, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - - xattention_test_params{ {{1, 31}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 32}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 1023}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 1024}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 127}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 128}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 129}}, 2, 2, 64, 64, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - xattention_test_params{ {{1, 32}}, 28, 28, 128, 128, 256, {0.9}, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token -})); From e794f5b1d10aa75f43f445afe02a9b3322bd1905 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 21 Oct 2025 16:00:08 +0800 Subject: [PATCH 88/96] Fix build error (#64) --- .../intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp | 2 +- .../tests/unit/test_cases/paged_attention_gpu_test.cpp | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index b539f1bda01c4a..bcb6f3e5cdaca0 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -659,7 +659,7 @@ DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { wgs.global = {rtp->N_kq_groups * (rtp->q_stride_pad / BLOCK_WG_M) * SG_N * WALK_HQ, SG_M, desc->heads_num / WALK_HQ}; wgs.local = {SG_N, SG_M, 1}; - const uint32_t q_start_strided = N - M; + const size_t q_start_strided = N - M; OPENVINO_ASSERT(N >= M, "length of key cache must be greater or equal than query"); auto& scalars = kd.params.scalars; diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 9de30c473d7806..ca46954aa9dd64 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -743,7 +743,7 @@ struct PagedAttentionReference { } auto window_size = pam.has_score_aggregation ? pam.score_aggregation[i] : 1; - + double th = static_cast(threshold.size() == 1 ? threshold[0] : threshold[i]); auto subsequence_ref_results = run_reference(has_xattention, pam.query_data[i], key_data, @@ -757,7 +757,7 @@ struct PagedAttentionReference { window_size, pam.sliding_window_size, pam.get_default_scale(), - static_cast(threshold[i])); + th); // concatenate all subsequences into one vector ref_data_output.insert(ref_data_output.end(), @@ -1598,8 +1598,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing: })); INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, xattention_test, ::testing::ValuesIn(std::vector{ - /* without scores output, static input query paddings, single sequence, disable KV cache compression, k_head_size==v_head_size, - token_size>=32, disable_mix_mode */ paged_attention_test_params{ {{32, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token paged_attention_test_params{ {{2048, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token From 5ff7d321fd355359f64d191c201edd0b33c755b2 Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 21 Oct 2025 20:51:35 +0800 Subject: [PATCH 89/96] Ww/cm xattention (#65) * throw exception if k_head_size != v_head_size and has_xattn * Add more test cases --- .../src/plugin/ops/paged_attention.cpp | 3 + .../test_cases/paged_attention_gpu_test.cpp | 72 +++++++++++++++---- 2 files changed, 61 insertions(+), 14 deletions(-) diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 0e21f395beebd8..1a169c51059e59 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -47,6 +47,9 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared auto k_head_size = has_rt_params ? rt_info.at(k_head_size_id).as() : key_cache_ps[k_head_size_idx].get_length(); auto v_head_size = has_rt_params ? rt_info.at(v_head_size_id).as() : value_cache_ps[3].get_length(); auto kv_heads_num = has_rt_params ? rt_info.at(num_k_heads_id).as() : key_cache_ps[1].get_length(); + if (prim.has_xattention) { + OPENVINO_ASSERT(k_head_size == v_head_size); + } // WA: in some cases, the query input may have a bounded dimension // Use input shape of the input node in such cases diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index ca46954aa9dd64..0c9e369f6f4351 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -162,9 +162,9 @@ struct PagedAttentionManager { const auto& subsequence_desc = subsequence_descs[i]; max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); - query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); - key_data.push_back(generate_input_data(rg, num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); - value_data.push_back(generate_input_data(rg, num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); + query_data.push_back(generate_realistic_data(num_heads, subsequence_desc.num_tokens, k_head_size)); + key_data.push_back(generate_realistic_data(num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); + value_data.push_back(generate_realistic_data(num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); past_lens.push_back(subsequence_desc.past_len); int subsequence_start_pos = subsequence_begins[i]; @@ -791,11 +791,43 @@ struct PagedAttentionReference { auto query_shape = ov::PartialShape{1, num_queries, num_heads, k_head_size}; auto key_shape = ov::PartialShape{1, num_keys, num_kv_heads, k_head_size}; auto value_shape = ov::PartialShape{1, num_keys, num_kv_heads, v_head_size}; - if (num_heads != num_kv_heads) { + if (num_heads != num_kv_heads && !has_xattention) { query_shape = ov::PartialShape{num_queries, num_kv_heads, (num_heads / num_kv_heads), k_head_size}; key_shape = ov::PartialShape{num_keys, num_kv_heads, 1, k_head_size}; value_shape = ov::PartialShape{num_keys, num_kv_heads, 1, v_head_size}; } + bool do_gqa_expand = false; + std::vector expanded_key_data; + std::vector expanded_value_data; + if (has_xattention) { + // Grouped Query Attention + do_gqa_expand = (num_heads != num_kv_heads); + if (do_gqa_expand) { + const int group_size = num_heads / num_kv_heads; + + expanded_key_data.resize(static_cast(num_keys) * static_cast(num_heads) * static_cast(k_head_size)); + expanded_value_data.resize(static_cast(num_keys) * static_cast(num_heads) * static_cast(v_head_size)); + + for (int key_idx = 0; key_idx < num_keys; ++key_idx) { + for (int h = 0; h < num_heads; ++h) { + const int src_kv_head = h / group_size; + size_t src_key_off = (static_cast(key_idx) * static_cast(num_kv_heads) + static_cast(src_kv_head)) * static_cast(k_head_size); + size_t dst_key_off = (static_cast(key_idx) * static_cast(num_heads) + static_cast(h)) * static_cast(k_head_size); + for (int d = 0; d < k_head_size; ++d) + expanded_key_data[dst_key_off + static_cast(d)] = key_data[src_key_off + static_cast(d)]; + + size_t src_val_off = (static_cast(key_idx) * static_cast(num_kv_heads) + static_cast(src_kv_head)) * static_cast(v_head_size); + size_t dst_val_off = (static_cast(key_idx) * static_cast(num_heads) + static_cast(h)) * static_cast(v_head_size); + for (int d = 0; d < v_head_size; ++d) + expanded_value_data[dst_val_off + static_cast(d)] = value_data[src_val_off + static_cast(d)]; + } + } + + key_shape = ov::PartialShape{1, num_keys, num_heads, k_head_size}; + value_shape = ov::PartialShape{1, num_keys, num_heads, v_head_size}; + num_kv_heads = num_heads; + } + } auto query_layout = layout{query_shape, data_types::f16, format::bfyx}; auto key_layout = layout{key_shape, data_types::f16, format::bfyx}; @@ -803,8 +835,13 @@ struct PagedAttentionReference { auto scale_layout = cldnn::layout({1}, data_types::f16, format::bfyx); OPENVINO_ASSERT(query_layout.count() == query_data.size()); - OPENVINO_ASSERT(key_layout.count() == key_data.size()); - OPENVINO_ASSERT(value_layout.count() == value_data.size()); + if (do_gqa_expand) { + OPENVINO_ASSERT(key_layout.count() == expanded_key_data.size()); + OPENVINO_ASSERT(value_layout.count() == expanded_value_data.size()); + } else { + OPENVINO_ASSERT(key_layout.count() == key_data.size()); + OPENVINO_ASSERT(value_layout.count() == value_data.size()); + } auto query_mem = test_engine.allocate_memory(query_layout); auto key_mem = test_engine.allocate_memory(key_layout); @@ -812,8 +849,13 @@ struct PagedAttentionReference { auto scale_mem = test_engine.allocate_memory(scale_layout); set_values(query_mem, query_data); - set_values(key_mem, key_data); - set_values(value_mem, value_data); + if (do_gqa_expand) { + set_values(key_mem, expanded_key_data); + set_values(value_mem, expanded_value_data); + } else { + set_values(key_mem, key_data); + set_values(value_mem, value_data); + } set_values(scale_mem, {static_cast(scale)}); ov::reference::XAttentionRetainedBlockIndicesForAllHeads retained_blocks; @@ -831,7 +873,7 @@ struct PagedAttentionReference { }; const auto query_data_3d = reorder_qhk_to_hqd(query_data, num_queries, num_heads, k_head_size); - const auto key_data_3d = reorder_qhk_to_hqd(key_data, num_keys, num_heads, k_head_size); + const auto key_data_3d = reorder_qhk_to_hqd(do_gqa_expand ? expanded_key_data : key_data, num_keys, num_heads, k_head_size); const size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; const size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; @@ -1400,6 +1442,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam { mismatch_count++; } } + std::cout << "mismatch_count: " << mismatch_count << std::endl; EXPECT_LE(mismatch_count, int(data_output_mem->count() * 0.04)); } @@ -1412,6 +1455,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam { mismatch_count++; } } + std::cout << "mismatch_count: " << mismatch_count << std::endl; EXPECT_LE(mismatch_count, int(scores_output_mem->count() * 0.04)); } } @@ -1419,11 +1463,8 @@ struct PagedAttentionTest : public ::testing::TestWithParam { static bool check_cm_available() { auto& engine = get_test_engine(); ExecutionConfig config = get_test_default_config(engine); - if (!cldnn::check_cm_jit_support(engine, config) || !engine.get_device_info().supports_immad) { - return false; - } - - return true; + return cldnn::check_cm_jit_support(engine, config) && + (engine.get_device_info().arch == gpu_arch::xe2 || engine.get_device_info().arch == gpu_arch::xe3); } }; @@ -1601,6 +1642,9 @@ INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, xattention_test, ::testing::Values paged_attention_test_params{ {{32, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token paged_attention_test_params{ {{2048, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{32, 0}}, 4, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 4, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{2048, 0}}, 4, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token paged_attention_test_params{ {{1, 31}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token paged_attention_test_params{ {{1, 32}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token From 26c4f2f93735a60e55237908bf8f380955e8250e Mon Sep 17 00:00:00 2001 From: Wang Wangwang Date: Tue, 21 Oct 2025 21:02:40 +0800 Subject: [PATCH 90/96] Remove debug messages (#66) --- .../tests/unit/test_cases/paged_attention_gpu_test.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 0c9e369f6f4351..536b844417fab2 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1442,7 +1442,6 @@ struct PagedAttentionTest : public ::testing::TestWithParam { mismatch_count++; } } - std::cout << "mismatch_count: " << mismatch_count << std::endl; EXPECT_LE(mismatch_count, int(data_output_mem->count() * 0.04)); } @@ -1455,7 +1454,6 @@ struct PagedAttentionTest : public ::testing::TestWithParam { mismatch_count++; } } - std::cout << "mismatch_count: " << mismatch_count << std::endl; EXPECT_LE(mismatch_count, int(scores_output_mem->count() * 0.04)); } } From 1ec3dfd99cc442b4bcbc7b2e0c4caa6c2f48a766 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 22 Oct 2025 10:09:08 +0800 Subject: [PATCH 91/96] fix the place to check kvcache precision --- src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index fcfb210d35cc48..0b1efeea33e140 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -561,7 +561,6 @@ void TransformationsPipeline::apply(std::shared_ptr func) { // v: [num_blocks, num_kv_heads, block_size(256), head_size] ov::pass::ConvertPagedAttnInputs::KVCacheConfig kv_cache_config; const auto kv_cache_precision = config.get_kv_cache_precision(); - OPENVINO_ASSERT(kv_cache_precision != ov::element::dynamic, "[GPU] kv_cache precision should be specified."); kv_cache_config.keyCachePrecision = kv_cache_precision; kv_cache_config.valueCachePrecision = kv_cache_precision; kv_cache_config.inferencePrecision = infer_precision; @@ -594,6 +593,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { const size_t group_num, int64_t& head_size, int64_t& block_size) { + OPENVINO_ASSERT(precision != ov::element::dynamic, "[GPU] kv_cache precision should be specified."); if (bychannel) { // TODO: need to handle group size != block size case if (precision == ov::element::i8 || precision == ov::element::u8) { From a6e4bbb6aa39cc8fb4613b76baf0e33d0b0f84f1 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 22 Oct 2025 10:31:09 +0800 Subject: [PATCH 92/96] useless code cleanup. --- .../src/graph/impls/cm/pa_kv_cache_update_ref.cm | 12 ------------ .../src/graph/impls/cm/pa_multi_token.cm | 5 +---- .../src/graph/impls/cm/pa_single_token.cm | 16 ---------------- .../impls/cm/pa_single_token_finalization.cm | 1 - .../src/graph/impls/cm/xattn_gemm_qk.cm | 1 - .../src/graph/impls/cm/xattn_post_proc.cm | 15 --------------- .../intel_gpu/src/graph/paged_attention.cpp | 15 +++++++++------ 7 files changed, 10 insertions(+), 55 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm index fc7563e567ab3c..b019d217bc99c4 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm @@ -25,7 +25,6 @@ constexpr uint wg_size = WG_SIZE; #define REG_K 16 -// extern "C" _GENX_MAIN_ void pa_kv_cache_update( extern "C" _GENX_MAIN_ void KERNEL_NAME( const half* key [[type("svmptr_t")]], const half* value [[type("svmptr_t")]], @@ -64,9 +63,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( const auto wg_local_id = cm_local_id(2); const auto local_size = cm_local_size(2); - // static_assert(local_size == wg_size); - - // const uint token_idx = wg_id * local_size + wg_local_id; const uint token_idx = cm_global_id(2); // token_idx -> subsequence_idx @@ -79,8 +75,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } } - // printf("wg:%d.%d, token_idx: %d, subsequence_idx: %d\n", wg_id, wg_local_id, token_idx, subsequence_idx); - const uint subsequence_begin_idx = subsequence_begins[subsequence_idx]; const uint past_len = past_lens[subsequence_idx]; const uint current_block_idx = (past_len + token_idx - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE; @@ -130,12 +124,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( uint value_out_offset = block_v_base_offset + token_start_pos * V_HEAD_SIZE; uint value_in_offset = token_idx * value_pitch + head_idx * V_HEAD_SIZE + value_offset; - //if(token_idx==0 && head_idx==0) - //{ - // printf("value_pitch = %d, value_in_offset: %d, value_out_offset: %d,V_HEAD_SIZE = %d, ADJUSTED_V_HEAD_SIZE = %d\n", - // value_pitch, value_in_offset, value_out_offset, V_HEAD_SIZE, ADJUSTED_V_HEAD_SIZE); - //} - vector value_data; value_data.format() = cm_ptr_load((int*)value, value_in_offset * (int)sizeof(half)); #if KV_CACHE_COMPRESSION_PER_TOKEN diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm index f9d10d85b4fdf5..4410060e47ca56 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm @@ -24,7 +24,6 @@ namespace KERNEL_NAME { #define USE_LSC 0 #endif -//extern "C" _GENX_MAIN_ void pa_multi_token( extern "C" _GENX_MAIN_ void KERNEL_NAME( //query [q_len, num_heads, S] half* query [[type("svmptr_t")]], @@ -62,7 +61,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #if CMPA_KVCACHE_U8 constexpr uint K_SLM_SIZE = (4*kv_step * head_size * sizeof(half)); constexpr uint V_SLM_SIZE = (4*kv_step * head_size * sizeof(half)); - constexpr uint Q_SLM_SIZE = 0;//(q_step * head_size * sizeof(half)) * local_size; + constexpr uint Q_SLM_SIZE = 0; cm_slm_init(K_SLM_SIZE + V_SLM_SIZE + Q_SLM_SIZE); @@ -121,7 +120,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( kv_stop = (wg_id + 1) * wg_seq_len + past_q_lens; if (kv_stop > kv_seq_len) kv_stop = kv_seq_len; } - // printf("###########wg:%d.%d q: %d, +%d kv: %d, +%d, kvstop:%d\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop); //Q/O[B, L, H, S] uint q_offset = (q_start_sg*num_heads + h)*head_size; @@ -134,7 +132,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( auto q_start_block = q_start_sg/ SPARSE_BLOCK_SIZE; block_mask_base = sparse_block_mask + (h * num_q_blocks + q_start_block) * num_k_blocks; wg_block_mask_base = sparse_block_mask_wg + (h * cm_group_count(2) + wg_id) * num_k_blocks; - // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d, x-attn: %d, %dx%d, %p, %p\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop, q_start_block, num_q_blocks, num_k_blocks, sparse_block_mask, block_mask_base); } #endif diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm index 058e4175d7091e..5db4e9a5dc3b8e 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -58,7 +58,6 @@ inline void prepack_to_VNNI_W2(matrix_ref input, matrix_ref rS = 0; #else @@ -219,14 +213,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( } #else cm_load(Kt.format(), b2dK.set_block_y(kv_pos)); - // Not need clean K cache: 1) col write will lead to huge perf drop; 2) softmax will clear unused scores - // if(kv_pos_end < kv_pos + KV_STEP) { - // auto KmatRef = Kt.format(); - // uint valid_cols = kv_pos_end - kv_pos; - // uint valid_cols_vnni = valid_cols * 2; - // for (int r = valid_cols_vnni; r < KV_STEP * 2; r++) - // KmatRef.select(0,r) = 0.0f; - // } #endif #else matrix temp; @@ -285,7 +271,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( auto rS_slice = rS[qi].format(); rS_slice = cm_mul(rS_slice, (float)SCALE_FACTOR); // convert scale_factor into (float), or it will be promoted to double - // printf("leftover_size = %d, leftover_aligned_size = %d, XE_ARCH = %d, KV_PARTITION_STEP_NUM * REG_N = %d\n", leftover_size, leftover_aligned_size, XE_ARCH, KV_PARTITION_STEP_NUM * REG_N); if(leftover_size > 0) { auto Svec = rS_slice.format(); for(int i = leftover_size; i < KV_PARTITION_STEP_NUM * REG_N; i++){ @@ -374,7 +359,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME( #endif #pragma unroll for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += REG_K, ki++) { - // uint kv_offset_y = kv_pos; #if KV_CACHE_COMPRESSION vector temp_scale = scale_vec.select(kv_pos); vector temp_zp = zp_vec.select(kv_pos); diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm index dcf0a17327e201..501ffe48e3ec69 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm @@ -16,7 +16,6 @@ //cm_sdpa_2nd_reduce extern "C" _GENX_MAIN_ void KERNEL_NAME( -// extern "C" _GENX_MAIN_ void cm_sdpa_2nd_reduce( float* input [[type("svmptr_t")]], half* output [[type("svmptr_t")]], float* lse [[type("svmptr_t")]], diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm index e3f23fd9ec65e2..24fb0ab4864b69 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm @@ -59,7 +59,6 @@ CM_INLINE void get_mn(uint& id_wg_m, uint& id_wg_n, uint M, uint N, int slice_no } } -// _GENX_MAIN_ void gemm_qk( extern "C" _GENX_MAIN_ void KERNEL_NAME( svmptr_t key_cache ATTR, svmptr_t query ATTR, diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm index 62410533eac194..0bbab306c2f5ef 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm @@ -17,12 +17,6 @@ namespace KERNEL_NAME { #include "estimate.hpp" -// NOTE: q_stride_pad / TOKEN_IN_BLOCK <= q_block_pad, case for q_stride_pad / TOKEN_IN_BLOCK < q_block_pad: -// query = 256*16+1, then -// q_stride_pad = 256 -// q_stride_pad / TOKEN_IN_BLOCK = 32 -// q_block_pad = div_up(256*16+1, 128) = 33 -// _GENX_MAIN_ void post_proc_mask( extern "C" _GENX_MAIN_ void KERNEL_NAME(svmptr_t block_mask ATTR, svmptr_t merged_block_mask ATTR, uint q_stride_pad, uint q_block_pad, uint k_block_pad) { // block_mask: [b, hq, q_block_pad, k_block_pad] // merged_block_mask: [b, hq, q_block_pad/MERGED_Q_NUM, k_block_pad] @@ -37,15 +31,6 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(svmptr_t block_mask ATTR, svmptr_t merge merged_block_mask += m_mereged * k_block_pad; block_mask += m_mereged * MERGED_Q_NUM * k_block_pad; vector one = 1; - // q is not inside mask, aka q=1~15 which is less than param `stride` - //for (int i = 0; i < MERGED_Q_NUM; i++) { - // auto q_stride_cur = m_mereged * MERGED_Q_NUM + i; - // if (q_stride_cur >= q_stride_pad / TOKEN_IN_BLOCK && q_stride_cur < q_block_pad) { - // for (int j = 0; j < k_block_pad; j += 32) { - // cm_ptr_store((int*)block_mask, j + i * k_block_pad, one.format()); - // } - // } - //} for (int j = 0; j < k_block_pad; j += 32) { vector new_mask = cm_ptr_load((int*)block_mask, j).format(); for (int i = 1; i < MERGED_Q_NUM; i++) { diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index 89c11ee000cfda..8ee2fa1ba6bd7d 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -57,7 +57,8 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no const auto& subseq_begins_ps = impl_param.get_input_layout(PagedAttentionInputIdx::SUBSEQUENCE_BEGINS).get_partial_shape(); bool valid_subseq_count = subseq_begins_ps.is_dynamic() || (subseq_begins_ps[0].get_length() == static_cast(2)); - OPENVINO_ASSERT(valid_subseq_count, "[GPU] Unexpected sub sequences count for XAttention. Got ", subseq_begins_ps[0].get_length() - 1); + if(valid_subseq_count) + OPENVINO_THROW("[GPU] Unexpected sub sequences count for XAttention. Got ", subseq_begins_ps[0].get_length() - 1); } std::vector output_layouts{ data_layout }; @@ -120,11 +121,11 @@ paged_attention_inst::typed_primitive_inst(network& network, const paged_attenti : parent(network, node) { const auto desc = node.get_primitive(); - // const auto k_head_size = desc->k_head_size; - // const auto v_head_size = desc->v_head_size; + const auto k_head_size = desc->k_head_size; + const auto v_head_size = desc->v_head_size; const auto heads_num = desc->heads_num; const auto kv_heads_num = desc->kv_heads_num; - // const auto pa_block_size = desc->block_size; + const auto pa_block_size = desc->block_size; if (desc->has_alibi) { const auto alibi_input_idx = 11; @@ -133,7 +134,9 @@ paged_attention_inst::typed_primitive_inst(network& network, const paged_attenti } OPENVINO_ASSERT(heads_num % kv_heads_num == 0); - // OPENVINO_ASSERT(k_head_size % pa_block_size == 0); - // OPENVINO_ASSERT(v_head_size % pa_block_size == 0); + if (!desc->has_xattention) { + OPENVINO_ASSERT(k_head_size % pa_block_size == 0); + OPENVINO_ASSERT(v_head_size % pa_block_size == 0); + } } } // namespace cldnn From bdf2e89f4a6f58893dc678f97fe36a86ba9091e3 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 22 Oct 2025 11:39:58 +0800 Subject: [PATCH 93/96] fix lint error --- src/plugins/intel_gpu/src/graph/paged_attention.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index 8ee2fa1ba6bd7d..fd084f76c37176 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -57,7 +57,7 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no const auto& subseq_begins_ps = impl_param.get_input_layout(PagedAttentionInputIdx::SUBSEQUENCE_BEGINS).get_partial_shape(); bool valid_subseq_count = subseq_begins_ps.is_dynamic() || (subseq_begins_ps[0].get_length() == static_cast(2)); - if(valid_subseq_count) + if (valid_subseq_count) OPENVINO_THROW("[GPU] Unexpected sub sequences count for XAttention. Got ", subseq_begins_ps[0].get_length() - 1); } From 8ba831adfe71f7ae2d517308e63c030c73efb1fb Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Thu, 23 Oct 2025 16:11:26 +0800 Subject: [PATCH 94/96] fix throw check --- src/plugins/intel_gpu/src/graph/paged_attention.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index fd084f76c37176..1b0bc58d79471a 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -57,7 +57,7 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no const auto& subseq_begins_ps = impl_param.get_input_layout(PagedAttentionInputIdx::SUBSEQUENCE_BEGINS).get_partial_shape(); bool valid_subseq_count = subseq_begins_ps.is_dynamic() || (subseq_begins_ps[0].get_length() == static_cast(2)); - if (valid_subseq_count) + if (!valid_subseq_count) OPENVINO_THROW("[GPU] Unexpected sub sequences count for XAttention. Got ", subseq_begins_ps[0].get_length() - 1); } From 5201cdf4ba2037585f8a9bd5cc3b5c96f55b9d63 Mon Sep 17 00:00:00 2001 From: gta Date: Fri, 24 Oct 2025 01:18:02 +0000 Subject: [PATCH 95/96] add allow_bypass_xattn --- .../include/intel_gpu/runtime/internal_properties.hpp | 1 + .../intel_gpu/include/intel_gpu/runtime/options.inl | 2 +- .../intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp | 8 ++++++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp index eb55e68ef6a609..3d26851e9ddb77 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp @@ -177,6 +177,7 @@ static constexpr Property dynamic_quantization static constexpr Property validate_output_buffer{"GPU_VALIDATE_OUTPUT_BUFFER"}; static constexpr Property mem_pool_util_threshold{"GPU_MEM_POOL_UTIL_THRESHOLD"}; static constexpr Property dump_src_after_exec{"GPU_DUMP_SRC_TENSORS_AFTER_EXEC"}; +static constexpr Property allow_bypass_xattn{"GPU_ALLOW_BYPASS_XATTN_EXEC"}; } // namespace ov::intel_gpu namespace cldnn { diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl index be2bf896326e62..7e8cd89e551c43 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl @@ -58,7 +58,7 @@ OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, could_use_flashattn_v2, true, " OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, dynamic_quantization_threshold, 64, "Apply dynamic quantization only when batch size is larger than this value in OneDNN") OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, weightless_attr, nullptr, "Used to configure ov::WeightlessCacheAttribute for constants that are not loaded from a .bin file. This typically applies to non-IR inputs (e.g., ORT)") OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, dynamic_quantization_precomputed_reduction, true, "Precompute reduction of activation for faster dynamic quantization in case of asymmetric weight") - +OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, allow_bypass_xattn, true, "Allow bypass xattn execution if threshold >= 1.0.") OV_CONFIG_DEBUG_GLOBAL_OPTION(ov::intel_gpu, help, false, "Print help message for all config options") OV_CONFIG_DEBUG_GLOBAL_OPTION(ov::intel_gpu, verbose, 0, "Enable logging for debugging purposes. The higher value the more verbose output. 0 - Disabled, 4 - Maximum verbosity") diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index bcb6f3e5cdaca0..8c15e790448f05 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -120,8 +120,12 @@ float get_xattn_thresh(const kernel_impl_params& params, const size_t seq_idx) { // either threshold is larger than 1.0, or, q_len is too small // to compute xattn block_mask. bool bypass_xattn(const kernel_impl_params& params) { - auto xattn_thresh = get_xattn_thresh(params); - bool bypass = xattn_thresh >= 1.0; + bool bypass = false; + bool allow_bypass = params.get_program().get_config().get_allow_bypass_xattn(); + if (allow_bypass) { + auto xattn_thresh = get_xattn_thresh(params); + bypass = xattn_thresh >= 1.0; + } auto q_len = params.output_layouts[0].get_shape()[0]; bypass |= q_len < static_cast(STRIDE); //# will slient drop the tails which is less than `stride` From 78259609a5dd3cff13e048b5c2d4d3aa8d091e81 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Fri, 24 Oct 2025 07:03:11 +0000 Subject: [PATCH 96/96] fix rt_params q_block_pad_merged not assigned issue --- src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp | 3 ++- .../intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp | 4 +--- .../intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp index 4de0b8ce5efefd..f2ddcfd67de1f6 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -71,6 +71,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { auto rt_params = static_cast(m_rt_params.get()); rt_params->q_block_pad = q_block_pad; rt_params->k_block_pad = k_block_pad; + rt_params->q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); const size_t head_size = desc->k_head_size; @@ -199,7 +200,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM { GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: count_kq_max_wg=" << count_kq_max_wg * 4 << " count_kq_exp_partial_sum=" << count_kq_exp_partial_sum * 4 << " count_elements_mask=" << count_elements_mask * 1 - << " count_elements_mask_merged=" << count_kq_exp_partial_sum * 1 << std::endl; + << " count_elements_mask_merged=" << count_elements_mask_merged * 1 << std::endl; } } diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp index 8c15e790448f05..ffcea9c6f51eba 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -781,9 +781,7 @@ DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const { auto& wgs = kd.params.workGroups; - const uint32_t q_block_pad_merged = ceil_div(rtp->q_block_pad, MERGED_Q_NUM); - - wgs.global = {q_block_pad_merged, desc->heads_num, 1}; + wgs.global = {rtp->q_block_pad_merged, desc->heads_num, 1}; wgs.local = {1, 1, 1}; auto& scalars = kd.params.scalars; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp index d4f6952cb2d35a..c630269d6f8696 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -51,7 +51,7 @@ struct PagedAttentionRuntimeParams : public ImplRuntimeParams { size_t q_block_pad; size_t k_block_pad; size_t q_stride_pad; - uint32_t q_block_pad_merged; + size_t q_block_pad_merged; size_t N_kq_groups; size_t M; size_t N;