diff --git a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp index 1b0b4b77933ee1..f96eff2c5ee4f2 100644 --- a/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp @@ -109,7 +109,7 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co value_cache->set_element_type(value_cache_precision); bool status = false; if (pa_op->get_rt_info().count("num_k_heads") && pa_op->get_rt_info().count("k_head_size") && - pa_op->get_rt_info().count("num_v_heads") && pa_op->get_rt_info().count("num_v_heads")) { + pa_op->get_rt_info().count("num_v_heads") && pa_op->get_rt_info().count("v_head_size")) { const auto key_cache_shape = init_cache_shape(pa_op->get_rt_info()["num_k_heads"].as(), pa_op->get_rt_info()["k_head_size"].as(), m_config.keyCacheBlockSize, diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index 14f22ea4f0032a..464c822f7c34f6 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -38,6 +38,7 @@ struct paged_attention : public primitive_base { }; static constexpr size_t block_size = 16; + static constexpr size_t block_size_xattn = 256; paged_attention() : primitive_base("", {}) {} diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp index 6fc03be44a88b8..a265f05337089a 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp @@ -178,6 +178,8 @@ static constexpr Property, ov::PropertyMutability::RW> static constexpr Property could_use_flashattn_v2{"GPU_COULD_USE_FLASHATTN_V2"}; static constexpr Property validate_output_buffer{"GPU_VALIDATE_OUTPUT_BUFFER"}; static constexpr Property mem_pool_util_threshold{"GPU_MEM_POOL_UTIL_THRESHOLD"}; +static constexpr Property dump_src_after_exec{"GPU_DUMP_SRC_TENSORS_AFTER_EXEC"}; +static constexpr Property allow_bypass_xattn{"GPU_ALLOW_BYPASS_XATTN_EXEC"}; } // namespace ov::intel_gpu namespace cldnn { diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl index 18c38dd836554b..4f0bfe77b83c7b 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl @@ -57,6 +57,7 @@ OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, asym_dynamic_quantization, fals OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, could_use_flashattn_v2, true, "Enable/Disable SDPA primitive executing with FlashAttenV2 online softmax tricks.") OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, dynamic_quantization_threshold, 64, "Apply dynamic quantization only when batch size is larger than this value in OneDNN") OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, dynamic_quantization_precomputed_reduction, true, "Precompute reduction of activation for faster dynamic quantization in case of asymmetric weight") +OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, allow_bypass_xattn, true, "Allow bypass 
of xattn execution when threshold >= 1.0.")
 OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, weightless_attr, nullptr, "Used to configure ov::WeightlessCacheAttribute for constants that are not loaded from a .bin file. This typically applies to non-IR inputs (e.g., ORT)")
@@ -81,6 +82,7 @@ OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_layer_names, std::vector
 OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_memory_pool_path, "", "Save csv file with memory pool info to specified folder")
 OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_memory_pool, false, "Enable verbose output for memory pool")
 OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_iterations, std::set{}, "Space separated list of iterations where other dump options should be enabled")
+OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, dump_src_after_exec, false, "Enable source data dump after layer execution. Useful for capturing updated states in stateful models.")
 OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, host_time_profiling, 0, "Measure and print host time spent from the beginning of the infer until all host work is done and plugin is ready to block thread on the final clFinish() call")
 OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_async_compilation, false, "Disable feature that allows to asynchronously prepare static-shaped implementations for the primitives with shape-agnostic kernels selected during compilation")
 OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_runtime_buffer_fusing, false, "Disable runtime inplace optimizations for operations like concat and crop")
diff --git a/src/plugins/intel_gpu/src/graph/debug_helper.cpp b/src/plugins/intel_gpu/src/graph/debug_helper.cpp
index 01e75cc7c0f2ed..fe986f18106fc4 100644
--- a/src/plugins/intel_gpu/src/graph/debug_helper.cpp
+++ b/src/plugins/intel_gpu/src/graph/debug_helper.cpp
@@ -264,7 +264,7 @@ void log_memory_to_file(memory::ptr mem, layout data_layout, stream& stream, std
         dump(actual_mem, stream, file_stream, dump_raw);
     else if (mem_dt == cldnn::data_types::u8)
         dump(actual_mem, stream, file_stream, dump_raw);
-    else if (mem_dt == cldnn::data_types::u8)
+    else if (mem_dt == cldnn::data_types::boolean)
         dump(actual_mem, stream, file_stream, dump_raw);
     else if (mem_dt == cldnn::data_types::i4 || mem_dt == cldnn::data_types::u4)
         dump_i4u4(mem_dt, actual_mem, stream, file_stream, dump_raw);
@@ -550,7 +550,7 @@ NodeDebugHelper::~NodeDebugHelper() {
         for (size_t i = 0; i < m_inst.get_intermediates_memories().size(); i++) {
             std::string name = get_file_prefix() + "_intermediates_" + std::to_string(i);
             auto output_mem = m_inst.get_intermediates_memories()[i];
-            if (output_mem == nullptr) {
-                GPU_DEBUG_COUT << " intermediates_mem is nullptr. Nothing to dump." << std::endl;
+            if (output_mem == nullptr || output_mem->size() == 0) {
+                GPU_DEBUG_COUT << " intermediates_mem is nullptr or empty. Nothing to dump." << std::endl;
                 continue;
             }
@@ -572,6 +572,35 @@ NodeDebugHelper::~NodeDebugHelper() {
             log_memory_to_file(output_mem, output_layout, m_stream, filename, dump_raw);
         }
     }
+
+    if (config.get_dump_src_after_exec()) {
+        for (size_t i = 0; i < m_inst.inputs_memory_count(); i++) {
+            std::string name = get_file_prefix() + "_updated_src_" + std::to_string(i);
+            auto output_mem = m_inst.input_memory_ptr(i);
+            if (output_mem == nullptr) {
+                GPU_DEBUG_COUT << " updated_input_mem is nullptr. Nothing to dump." << std::endl;
+                continue;
+            }
+
+            auto& output_layout = m_inst.get_input_layout(i);
+            if (config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::binary) {
+                // Binary dump : raw
+                auto filename = get_file_path_for_binary_dump(output_layout, name, config.get_dump_tensors_path());
+
+                mem_lock lock(output_mem, m_stream);
+                ov::util::save_binary(filename, lock.data(), output_mem->size());
+                GPU_DEBUG_COUT << " Dump layer src : " << layer_name << " to " << filename << std::endl;
+                debug_str_for_bin_load += (filename + ",");
+            } else {
+                const bool dump_raw = config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::text_raw;
+                GPU_DEBUG_COUT << " Dump " << (dump_raw ? "raw " : "") << name << std::endl;
+                auto filename = config.get_dump_tensors_path() + get_name_for_dump(name) + ".txt";
+                // Text dump
+                log_memory_to_file(output_mem, output_layout, m_stream, filename, dump_raw);
+            }
+        }
+    }
+
     if (config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::binary && m_inst.is_input()) {
         debug_str_for_bin_load[debug_str_for_bin_load.size()-1] = '\"';
         GPU_DEBUG_COUT << debug_str_for_bin_load << std::endl;;
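A minimal usage sketch for the two internal properties added above, assuming the usual ov::intel_gpu::ExecutionConfig plumbing (the getter is the accessor generated by the OV_CONFIG_* macros, the same one debug_helper.cpp calls below); the helper name is illustrative, not part of this change:

#include "intel_gpu/runtime/execution_config.hpp"
#include "intel_gpu/runtime/internal_properties.hpp"

// Hypothetical helper: enable the post-execution source dump and force xattn to run.
void configure_for_debugging(ov::intel_gpu::ExecutionConfig& config) {
    config.set_property(ov::AnyMap{
        ov::intel_gpu::dump_src_after_exec(true),    // debug option registered above
        ov::intel_gpu::allow_bypass_xattn(false),    // release-internal option registered above
    });
    bool dump_enabled = config.get_dump_src_after_exec();  // generated accessor
    (void)dump_enabled;
}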
diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp
new file mode 100644
index 00000000000000..325cd25fc5ba59
--- /dev/null
+++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp
@@ -0,0 +1,415 @@
+/*******************************************************************************
+ * Copyright (c) 2022-2025 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/ +#include +#include + +//# CM-compiler is C++17 +static_assert(__cplusplus >= 201703L); +//# static_assert(__cplusplus >= 202002L); +//# static_assert(__cplusplus >= 202302L); + +#define SystolicDepth 8 +#define RepeatCount 8 +#define VNNI_WIDTH 2 +#define REG_K (SystolicDepth * VNNI_WIDTH) +#define REG_M RepeatCount +//REG_N +// Xe1: 8 +// Xe2: 16 +#define REG_N (CM_GRF_WIDTH/32) + +#define kv_step REG_K +#define q_step REG_N + +constexpr float scale_factor = CMFLA_SCALE_FACTOR; + +static_assert(q_step == 16 || q_step == 8); +static_assert(kv_step == 16); +static_assert(CM_HAS_DPAS); + +#define DEBUG_SHOW 1 +#if !DEBUG_SHOW +template +void show(const matrix mat, bool isfloat=true) { +} +#else +template +void show(const matrix mat, bool isfloat=true) { + printf("Matrix [%d, %d]:\n", M, N); + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + if (isfloat) + printf("%8.4f,", mat[m][n]); + else + printf("%8d,", mat[m][n]); + + } + printf("],\n"); + } + printf("]\n"); +} +#endif +template +CM_INLINE void Transpose_16x16(matrix_ref in, + matrix_ref out) { + matrix bBuf; + bBuf.row(0) = in.template select<4, 1, 4, 4>(0, 0); // 0,4,8,c + bBuf.row(1) = in.template select<4, 1, 4, 4>(4, 0); // 0,4,8,c + bBuf.row(2) = in.template select<4, 1, 4, 4>(8, 0); // 0,4,8,c + bBuf.row(3) = in.template select<4, 1, 4, 4>(12, 0); // 0,4,8,c + bBuf.row(4) = in.template select<4, 1, 4, 4>(0, 1); // 1,5,9,d + bBuf.row(5) = in.template select<4, 1, 4, 4>(4, 1); // 1,5,9,d + bBuf.row(6) = in.template select<4, 1, 4, 4>(8, 1); // 1,5,9,d + bBuf.row(7) = in.template select<4, 1, 4, 4>(12, 1); // 1,5,9,d + bBuf.row(8) = in.template select<4, 1, 4, 4>(0, 2); // 2,6,a,e + bBuf.row(9) = in.template select<4, 1, 4, 4>(4, 2); // 2,6,a,e + bBuf.row(10) = in.template select<4, 1, 4, 4>(8, 2); // 2,6,a,e + bBuf.row(11) = in.template select<4, 1, 4, 4>(12, 2); // 2,6,a,e + bBuf.row(12) = in.template select<4, 1, 4, 4>(0, 3); // 3,7,b,f + bBuf.row(13) = in.template select<4, 1, 4, 4>(4, 3); // 3,7,b,f + bBuf.row(14) = in.template select<4, 1, 4, 4>(8, 3); // 3,7,b,f + bBuf.row(15) = in.template select<4, 1, 4, 4>(12, 3); // 3,7,b,f + + out.row(0) = bBuf.template select<4, 1, 4, 4>(0, 0); // 0 + out.row(1) = bBuf.template select<4, 1, 4, 4>(4, 0); // 1 + out.row(2) = bBuf.template select<4, 1, 4, 4>(8, 0); // 2 + out.row(3) = bBuf.template select<4, 1, 4, 4>(12, 0); // 3 + out.row(4) = bBuf.template select<4, 1, 4, 4>(0, 1); // 4 + out.row(5) = bBuf.template select<4, 1, 4, 4>(4, 1); // 5 + out.row(6) = bBuf.template select<4, 1, 4, 4>(8, 1); // 6 + out.row(7) = bBuf.template select<4, 1, 4, 4>(12, 1); // 7 + out.row(8) = bBuf.template select<4, 1, 4, 4>(0, 2); // 8 + out.row(9) = bBuf.template select<4, 1, 4, 4>(4, 2); // 9 + out.row(10) = bBuf.template select<4, 1, 4, 4>(8, 2); // a + out.row(11) = bBuf.template select<4, 1, 4, 4>(12, 2); // b + out.row(12) = bBuf.template select<4, 1, 4, 4>(0, 3); // c + out.row(13) = bBuf.template select<4, 1, 4, 4>(4, 3); // d + out.row(14) = bBuf.template select<4, 1, 4, 4>(8, 3); // e + out.row(15) = bBuf.template select<4, 1, 4, 4>(12, 3); // f +} + +template +CM_INLINE void Transpose_8x8(matrix_ref in, matrix_ref out) { + matrix temp; + temp.row(0) = in.template select<2, 1, 4, 2>(0, 0); + temp.row(1) = in.template select<2, 1, 4, 2>(2, 0); + temp.row(2) = in.template select<2, 1, 4, 2>(4, 0); + temp.row(3) = in.template select<2, 1, 4, 2>(6, 0); + temp.row(4) = in.template 
select<2, 1, 4, 2>(0, 1); + temp.row(5) = in.template select<2, 1, 4, 2>(2, 1); + temp.row(6) = in.template select<2, 1, 4, 2>(4, 1); + temp.row(7) = in.template select<2, 1, 4, 2>(6, 1); + + out.row(0) = temp.template select<4, 1, 2, 4>(0, 0); + out.row(2) = temp.template select<4, 1, 2, 4>(0, 1); + out.row(4) = temp.template select<4, 1, 2, 4>(0, 2); + out.row(6) = temp.template select<4, 1, 2, 4>(0, 3); + out.row(1) = temp.template select<4, 1, 2, 4>(4, 0); + out.row(3) = temp.template select<4, 1, 2, 4>(4, 1); + out.row(5) = temp.template select<4, 1, 2, 4>(4, 2); + out.row(7) = temp.template select<4, 1, 2, 4>(4, 3); +} + +// function templates cannot be partially specialized; use overloading to achieve the same effect +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_8x8(in, out); +} +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_16x16(in, out); +} +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); + Transpose_8x8(in.select<8, 1, 8, 1>(8,0), out.select<8, 1, 8, 1>(0,8)); +} +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); + Transpose_8x8(in.select<8, 1, 8, 1>(0,8), out.select<8, 1, 8, 1>(8,0)); +} + +template +CM_INLINE void slm_read_2d(matrix_ref out, uint slm, int offset) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_slm_block_read(slm, GENX_DWALIGNED, offset + i*n_stride*sizeof(T), out.row(i)); + } +} + +template +CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_svm_block_read(base + i * pitch, out[i]); + } +} + +template +CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + out.row(i).format() = cm_load(base, offset + i * pitch); + } +} + +template +CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + out.row(i).format() = cm_load(base, offset + i * pitch); + } +} + +template +CM_INLINE void cm_store_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_store(base, offset + i * pitch, out.row(i).format()); + } +} + +template +CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, vector_ref offsets) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_svm_block_read(base + offsets[i], out[i]); + } +} + +template +CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++, base += pitch, n_rows--) { + if (n_rows > 0) cm_svm_block_read(base, out[i]); + } +} + +template +CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++, base += pitch) { + cm_svm_block_write(base, out[i]); + } +} + +template +CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++, base += pitch) { + if (i < n_rows) cm_svm_block_write(base, out[i]); + } +} + +CM_INLINE uint64_t get_clock() { + auto clk = cm_clock(); + return ((uint64_t)clk[1]) << 32 | clk[0]; +} + + +template +inline matrix ugemm_KQ(uint slm_K, matrix_ref Qt, 
uint slm_offset = 0) { + matrix St; + constexpr int num_K = _kv_step/REG_M; + auto St2 = St.format(); + + matrix Kmat; + cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); + // if (cm_local_id(2) == 3 && cm_group_id(2) == 0) { + // show(Kmat.format()); + // } + #pragma unroll + for(int k = 0; k < num_K; k++) + St2.row(k) = cm_dpas(0, Qt[0].format(), Kmat[k].format()); + + #pragma unroll + for(int ri = 1; ri < num_Qt; ri++) { + cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); + #pragma unroll + for(int k = 0; k < num_K; k++) { + St2.row(k) = cm_dpas(St2.row(k), Qt[ri].format(), Kmat[k].format()); + } + } + return St; +} + +template +inline void ugemm_PV0(uint slm_V, matrix_ref P, matrix_ref rO, uint slm_offset = 0) { + constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; + + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri = 0; k < _head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); + // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { + // show(Vmat.format()); + // } + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + 0, + Vmat.format(), + P2.row(p).format()); + //show(rO[ri + p].format()); + } + } +} + +template +inline void ugemm_PV1(uint slm_V, matrix_ref P, vector_ref max_comp, + matrix_ref rO, uint slm_offset = 0) { + constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri=0; k < _head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + + cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); + // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { + // show(Vmat.format()); + // } + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < REG_M; r++) + cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); + } + + //show(rO[ri].format()); + + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + rO[ri + p].format(), + Vmat.format(), + P2.row(p).format()); + //if (kv_pos == args_verbose) show(rO[ri + p].format()); + } + // if (kv_pos == args_verbose) show(cur_O.format()); + } +} + +template +vector online_softmax_update(matrix_ref St, vector_ref cur_max, vector_ref cur_sum) { + vector new_max_t; + new_max_t = cm_max(St[0], St[1]); + for(int r = 2; r < St.n_rows(); r++) new_max_t = cm_max(new_max_t, St[r]); + new_max_t = cm_max(new_max_t, cur_max); + + // Pt = torch.exp(St - new_max) + constexpr float log2e = 1.4426950408889634f; + for(int r = 0; r < St.n_rows(); r++) St[r] = cm_exp((St[r] - new_max_t)*log2e); + + vector row_sum_t; + row_sum_t = cm_add(St[0], St[1]); + for(int r = 2; r < St.n_rows(); r++) row_sum_t = cm_add(row_sum_t, St[r]); + + vector max_comp; + max_comp = cm_exp((cur_max - new_max_t)*log2e); + cur_sum = cm_mul(cur_sum, max_comp); + cur_sum = cm_add(cur_sum, row_sum_t); + cur_max = new_max_t; + return max_comp; +} + +#ifdef CM_HAS_LSC_UNTYPED_2D + #define cm_load_normal cm_load + #define cm_load_transpose cm_load + #define cm_load_vnni cm_load + #define cm_store_normal cm_store +#else + // simulation of LSC API using SVM API + template + inline void cm_load_normal(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) { + static_assert(NBlocks == 1); + auto pitch = Desc.get_pitch() + 1; + auto base = reinterpret_cast(Desc.get_base() + 
Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T));
+        #pragma unroll
+        for(int i = 0; i < BlockH; i++) {
+            cm_svm_block_read(base + i * pitch, Res.select(i*BlockW));
+        }
+    }
+
+    template
+    inline void cm_load_transpose(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) {
+        static_assert(NBlocks == 1);
+        auto pitch = Desc.get_pitch() + 1;
+        auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T));
+        matrix temp;
+        #pragma unroll
+        for(int i = 0; i < BlockH; i++) {
+            cm_svm_block_read(base + i * pitch, temp[i]);
+        }
+        Transpose2DMatrix(temp, Res.format());
+    }
+
+    // in the VNNI case, NBlocks increases along the X dimension (improves cache-line usage)
+    template
+    inline void cm_load_vnni(vector_ref Res, const lsc::block_2d_desc &Desc, int16_t Pred = 1) {
+        static_assert(NBlocks == 1 || NBlocks == 2);
+        // each block must be a full XMX B matrix
+        static_assert(BlockH == REG_K);
+        static_assert(BlockW == REG_N);
+        auto pitch = Desc.get_pitch() + 1;
+        auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T));
+        matrix temp;
+        #pragma unroll
+        for(int i = 0; i < BlockH; i++) {
+            cm_svm_block_read(base + i * pitch, temp[i]);
+        }
+
+        auto out_vnni = Res.format();
+        #pragma unroll
+        for(int i = 0; i < NBlocks; i ++) {
+            out_vnni.select(i*(BlockH/2), 0) = temp.select(0, i*BlockW);
+            out_vnni.select(i*(BlockH/2), 1) = temp.select(1, i*BlockW);
+        }
+    }
+
+    template
+    inline void cm_store_normal(const lsc::block_2d_desc &Desc, vector_ref Res) {
+        static_assert(NBlocks == 1);
+        auto pitch = Desc.get_pitch() + 1;
+        auto base = reinterpret_cast(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T));
+        #pragma unroll
+        for(int i = 0; i < BlockH; i++) {
+            cm_svm_block_write(base + i * pitch, Res.select(i*BlockW));
+        }
+    }
+#endif
+
+//===============================================================================================
+template
+constexpr void apply_causal_mask(matrix_ref St) {
+    if constexpr (i < N) {
+        St.row(i).select(0) = -3.4e38f;
+        apply_causal_mask(St);
+    }
+}
+
+//prepack [K, N] to [K/2, N, 2] layout.
+template
+inline void prepackAsVNNIWidth2(matrix_ref input, matrix_ref out) {
+    #pragma unroll
+    for (int r = 0; r < K/2; r++) {
+        out.row(r).select(0) = input.row(r*2);
+        out.row(r).select(1) = input.row(r*2+1);
+    }
+}
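prepackAsVNNIWidth2() above (like cm_load_vnni() in the LSC-emulation path) produces the VNNI-2 layout that dpas expects for its B operand: consecutive pairs of K-rows interleaved element-wise. A plain-array illustration of the [K, N] -> [K/2, N, 2] shuffle, not part of the header:

#include <cstdio>

// Illustrative VNNI-2 repack on plain arrays: out[r] interleaves input rows 2r and 2r+1.
int main() {
    const int K = 4, N = 2;
    int in[K][N] = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
    int out[K / 2][N * 2];  // [K/2, N, 2] with the last two dims flattened
    for (int r = 0; r < K / 2; r++)
        for (int c = 0; c < N; c++) {
            out[r][2 * c]     = in[2 * r][c];      // element from the even row
            out[r][2 * c + 1] = in[2 * r + 1][c];  // element from the odd row
        }
    printf("%d %d %d %d\n", out[0][0], out[0][1], out[0][2], out[0][3]);  // prints: 1 3 2 4
    return 0;
}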
diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp
new file mode 100644
index 00000000000000..f9a7e3fd1aaa03
--- /dev/null
+++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp
@@ -0,0 +1,505 @@
+/*******************************************************************************
+ * Copyright (c) 2022-2025 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+#include "cm_attention_common.hpp"
+
+#if CMPA_KVCACHE_U8
+template
+void pa_lsc_u8(
+    uint slm_K,
+    uint slm_V,
+    int wg_local_id,
+    int local_size,
+    int q_start,
+    int kv_stop,
+    int q_len,
+    int kv_len,
+    svmptr_t q_base [[type("svmptr_t")]],
+    svmptr_t k_cache_base [[type("svmptr_t")]],
+    svmptr_t v_cache_base [[type("svmptr_t")]],
+#if SPARSE_BLOCK_SIZE > 1
+    svmptr_t sparse_mask_base [[type("svmptr_t")]],
+    svmptr_t wg_sparse_mask_base [[type("svmptr_t")]],
+    bool validate,
+#endif
+    svmptr_t o_base [[type("svmptr_t")]],
+    int32_t past_lens,
+    int32_t* block_indices [[type("svmptr_t")]]) {
+
+    constexpr uint o_pitch = (num_heads * head_size * sizeof(half));
+    constexpr uint q_pitch = is_q_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch;
+    //[block_num, kv_heads, block_size, head_size]
+    constexpr uint kv_pitch = head_size * sizeof(uint8_t);
+
+    vector cur_max;
+    vector cur_sum;
+
+    cur_max = -3e38f;
+    cur_sum = 0;
+    constexpr int num_P_tiles = REG_N / REG_M;
+    matrix rQ;
+    matrix rO;
+
+    auto q_tokens_left = q_len;
+    static_assert(q_step == REG_N);
+    static_assert(kv_step == REG_K);
+
+    if (q_tokens_left < 0) q_tokens_left = 0;
+    if (q_tokens_left > q_step) q_tokens_left = q_step;
+
+    if (q_tokens_left > 0) {
+        lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0);
+        #pragma unroll
+        for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) {
+            cm_load(rQ[ri].format(), b2dQ.set_block_x(k));
+            rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor);
+        }
+    }
+
+    lsc::block_2d_desc b2dK(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0);
+    lsc::block_2d_desc b2dV(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0);
+    constexpr int quan_blk_stride = CMFLA_NUM_KV_HEADS * (CMFLA_HEAD_SIZE+4) * CMPA_BLOCK_SZ * sizeof(uint8_t);
+    int causal_left = q_start+past_lens;
+
+    constexpr uint slm_buff_size = kv_step * head_size * sizeof(half);
+    int slm_buff_id_write = 0;
+    int slm_buff_id_read = 0;
+
+#if SPARSE_BLOCK_SIZE > 1
+    auto skip_compute = [&](int kv_pos) {
+        auto kv_start_block = kv_pos / SPARSE_BLOCK_SIZE;
+        bool sparse_mask = *(reinterpret_cast(sparse_mask_base) + kv_start_block);
+
+        return !sparse_mask;
+    };
+    auto skip_load = [&](int kv_pos) {
+        auto kv_start_block = kv_pos / SPARSE_BLOCK_SIZE;
+        bool sparse_mask = *(reinterpret_cast(wg_sparse_mask_base) + kv_start_block);
+        return !sparse_mask;
+    };
+#endif
+
+    // Dequantizes one kv_step chunk of u8 K/V (per-token dscale/zp) into half and stages it
+    // in a 4-slot SLM ring buffer; the first half of the workgroup handles K, the second half V.
+    auto load_slm_KV = [&](int kv_pos) {
+        if (kv_pos < kv_stop) {
+#if SPARSE_BLOCK_SIZE > 1
+            if (validate) {
+                if (skip_load(kv_pos)) {
+                    slm_buff_id_write++;
+                    return;
+                }
+            }
+#endif
+            auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ];
+            uint32_t dscale_offset = cur_block_id*quan_blk_stride + \
+                CMPA_BLOCK_SZ * head_size * sizeof(uint8_t) + kv_pos%CMPA_BLOCK_SZ*sizeof(half);
+
+            uint slm_offset = (slm_buff_id_write & 3) * slm_buff_size;
+            vector dscale;
+            vector zp;
+            int kv_left = (kv_stop-kv_pos) > kv_step ?
kv_step: (kv_stop-kv_pos); + + slm_buff_id_write ++; + if (wg_local_id < local_size/2) { + cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset), dscale); + cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset + CMPA_BLOCK_SZ*sizeof(half)), zp); + + matrix kmat; + auto quanKmat = kmat.format()[1].format(); + b2dK.set_base_ptr(reinterpret_cast(k_cache_base+cur_block_id*quan_blk_stride)); + b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ); + + for(int k = REG_K*wg_local_id; k < head_size; k += REG_K*(local_size/2)) { + cm_load(quanKmat.format(), b2dK.set_block_x(k)); + /*@bug: cm compiler in the tail process. + : loop combined with type convert. + for(int r = 0; r < kv_left; r++) { + kmat[r] = quanKmat[r]-zp[r]; + kmat[r] = cm_mul(kmat[r], dscale[r]); + } + wa: unroll all kv_step rows. set 0 to padding rows. + */ + #pragma unroll + for(int r = 0; r < kv_step; r++) { + kmat[r] = quanKmat[r]-zp[r]; + kmat[r] = cm_mul(kmat[r], dscale[r]); + } + //clear unused data to 0. + for(int r = kv_step-1; r >= kv_left; r--) + kmat[r] = 0; + cm_slm_block_write(slm_K, slm_offset + k * kv_step * sizeof(half), kmat.format()); + } + } else { + cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset), dscale); + cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset+CMPA_BLOCK_SZ*sizeof(half)), zp); + + matrix VmatVNNI; + matrix Vmat; + auto quanVmat = Vmat.format().row(1).format(); + b2dV.set_base_ptr(reinterpret_cast(v_cache_base+cur_block_id*quan_blk_stride)); + b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + + #pragma unroll + for(int k = REG_N*(wg_local_id-(local_size/2)); k < head_size; k += REG_N*(local_size/2)) { + cm_load(quanVmat.format(), b2dV.set_block_x(k)); + /*@bug: cm compiler in the tail process. + : loop combined with type convert. + for(int r = 0; r < kv_left; r++) { + Vmat[r] = quanVmat[r]-zp[r]; + Vmat[r] = cm_mul(Vmat[r], dscale[r]); + } + */ + #pragma unroll + for(int r = 0; r < kv_step;r++) { + Vmat[r] = quanVmat[r]-zp[r]; + Vmat[r] = cm_mul(Vmat[r], dscale[r]); + } + + for(int r = kv_step-1; r>=kv_left;r--) { + Vmat[r] = 0; + } + prepackAsVNNIWidth2(Vmat, VmatVNNI); + cm_slm_block_write(slm_V, slm_offset + k * REG_K * sizeof(half), VmatVNNI.format()); + } + } + } + }; + + load_slm_KV(0); + load_slm_KV(kv_step); + cm_slm_fence(CM_LOCAL_BARRIER); + cm_sbarrier(1); + + for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step,slm_buff_id_read++) { + + // load0, load1, signal1, + // [wait1, signal2, load2, read0, compute0] + // [wait2, signal3, load3, read1, compute1] + // [wait3, signal4, load4, read2, compute2] + // [wait4, signal5, load5, read3, compute3] + // + // after wait3, all workers have reached signal3, so: + // - all workers have finished load2 & read0. 
+ // - we can start to load 4 into SLM slot 0 (i & 3) safely + // - we can start to read 2 ((i-2) & 3) safely + + + cm_fence(CM_LOCAL_BARRIER); + cm_sbarrier(0); + //if (kv_pos > 1024000) + if (kv_pos + kv_step < kv_stop) + cm_sbarrier(1); + load_slm_KV(kv_pos + kv_step*2); + + +#if SPARSE_BLOCK_SIZE > 1 + if (validate) { + if (skip_compute(kv_pos)) { + if constexpr (use_causal_mask) + causal_left -= kv_step; + continue; + } + } +#endif + { + + uint slm_offset = (slm_buff_id_read & 3) * slm_buff_size; + + //# St = k @ Qt + matrix St = ugemm_KQ(slm_K, rQ, slm_offset); + if constexpr (use_causal_mask) { + // since kv_step == q_step == 16, causal_left is n*kv_step + if (causal_left == 0) { + apply_causal_mask<1>(St); + } else if (causal_left < 0) { + St = -3.4e38f; + } + causal_left -= kv_step; + } else { + int kv_tokens = kv_stop - kv_pos; + // LSC ensures no overflow-access, but mask off k-tails attn-score is still required + for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; + } + auto max_comp = online_softmax_update(St, cur_max, cur_sum); + + matrix P; + Transpose2DMatrix(St, P); + + if (kv_pos == 0) + ugemm_PV0(slm_V, P, rO, slm_offset); + else + ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); + } + } + // cm_sbarrier(0); + if (q_tokens_left == 0) return; + + //# save cur_O/cur_sum.transpose(0, 1) + matrix cur_O_f16; + cur_sum = cm_inv(cur_sum); + + lsc::block_2d_desc b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0); + + #pragma unroll + for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < cO.n_rows(); r++) { + cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); + } + } + b2dO.set_block_x(k); + cm_store(b2dO.set_block_y(0), cur_O_f16.format().row(0)); + cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); + } +} + +#else + +template +void pa_kernel_lsc_prefetch_f16( + int wg_local_id, + int q_start, + int kv_stop, // + int q_len, //q_step + int kv_len, //not used for now + svmptr_t q_base [[type("svmptr_t")]], + svmptr_t k_cache_base [[type("svmptr_t")]], + svmptr_t v_cache_base [[type("svmptr_t")]], +#if SPARSE_BLOCK_SIZE > 1 + svmptr_t sparse_mask_base [[type("svmptr_t")]], + svmptr_t wg_sparse_mask_base [[type("svmptr_t")]], + bool validate, +#endif + svmptr_t o_base [[type("svmptr_t")]], + int32_t past_lens, + int32_t* block_indices [[type("svmptr_t")]]) { + constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); + constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; + // constexpr uint k_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); + // constexpr uint v_pitch = is_qkv_fused ? 
// q_pitch : (num_kv_heads * head_size * sizeof(half));
+    //[block_num, kv_heads, block_size, head_size]
+    constexpr uint k_pitch = head_size * sizeof(half);
+    constexpr uint v_pitch = k_pitch;
+
+    vector cur_max;
+    vector cur_sum;
+
+    cur_max = -3e38f;
+    cur_sum = 0;
+    constexpr int num_P_tiles = REG_N / REG_M;
+    matrix rQ;
+    matrix rO;
+
+    auto q_tokens_left = q_len;// - q_start;
+    static_assert(q_step == REG_N);
+    static_assert(kv_step == REG_K);
+
+    if (q_tokens_left < 0) q_tokens_left = 0;
+    if (q_tokens_left > q_step) q_tokens_left = q_step;
+
+    if (q_tokens_left > 0) {
+        lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0);
+        #pragma unroll
+        for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) {
+            cm_load(rQ[ri].format(), b2dQ.set_block_x(k));
+            rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor);
+        }
+    }
+
+    lsc::block_2d_desc b2dK(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0);
+    lsc::block_2d_desc b2dV(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0);
+
+    static_assert(wg_local_size == 16);
+    lsc::block_2d_desc prefetch_K(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0);
+    lsc::block_2d_desc prefetch_V(v_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0);
+    constexpr int blk_stride = CMFLA_NUM_KV_HEADS*CMFLA_HEAD_SIZE*CMPA_BLOCK_SZ;
+    int causal_left = q_start+past_lens;
+
+    for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step) {
+        auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ];
+        //For the last step, duplicate prefetch here.
+        uint32_t prefetch_kv_pos = (kv_pos+kv_step) >= kv_stop ? kv_pos : (kv_pos+kv_step);
+        auto prefetch_block_id = block_indices[prefetch_kv_pos / CMPA_BLOCK_SZ];
+        //# St = k @ Qt
+        matrix St;
+        {
+            constexpr int num_K = kv_step/REG_M;
+            auto St2 = St.format();
+
+            matrix Kmat;
+
+            prefetch_K.set_base_ptr((reinterpret_cast(k_cache_base)+prefetch_block_id*blk_stride));
+            prefetch_K.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ);
+            cm_prefetch(prefetch_K.set_block_x(0));
+
+#if SPARSE_BLOCK_SIZE > 1
+            if (validate)
+            {
+                auto kv_start_block = kv_pos / SPARSE_BLOCK_SIZE;
+                bool sparse_mask = *(reinterpret_cast(sparse_mask_base) + kv_start_block);
+                if (!sparse_mask) {
+                    if constexpr (use_causal_mask) {
+                        causal_left -= kv_step;
+                    }
+                    continue;
+                }
+            }
+#endif
+            b2dK.set_base_ptr((reinterpret_cast(k_cache_base)+cur_block_id*blk_stride));
+            b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ);
+            cm_load(Kmat.format(), b2dK.set_block_x(0));
+            // sometimes the KV cache is filled with random NaN, so the unused key data needs to be cleaned up.
+            if ((kv_pos + kv_step) > kv_stop) {
+                auto valid_rows = kv_stop - kv_pos;
+                for (int r = valid_rows; r < kv_step; r++)
+                    Kmat.format().row(r) = 0.f;
+            }
+            #pragma unroll
+            for(int k = 0; k < num_K; k++)
+                St2.row(k) = cm_dpas(
+                    0,
+                    rQ[0].format(),
+                    Kmat[k].format());
+
+            #pragma unroll
+            for(int ri = 1; ri < head_size/REG_K; ri++) {
+                cm_prefetch(prefetch_K.set_block_x(ri*REG_K));
+                cm_load(Kmat.format(), b2dK.set_block_x(ri*REG_K));
+                #pragma unroll
+                for(int k = 0; k < num_K; k++) {
+                    St2.row(k) = cm_dpas(
+                        St2.row(k),
+                        rQ[ri].format(),
+                        Kmat[k].format());
+                }
+            }
+        }
+        if constexpr (use_causal_mask) {
+            // since kv_step == q_step == 16, causal_left is n*kv_step
+            if (causal_left == 0) {
+                apply_causal_mask<1>(St);
+            } else if (causal_left < 0) {
+                St = -3.4e38f;
+            }
+            causal_left -= kv_step;
+        } else {
+            int kv_tokens = kv_stop - kv_pos;
+            // LSC ensures no overflow-access, but masking off the k-tail attn-scores is still required
+            for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f;
+        }
+
+        //show(St);
+        auto max_comp = online_softmax_update(St, cur_max, cur_sum);
+
+        matrix P;
+        Transpose2DMatrix(St, P);
+
+        prefetch_V.set_base_ptr((reinterpret_cast(v_cache_base)+prefetch_block_id*blk_stride));
+        prefetch_V.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ);
+
+        b2dV.set_base_ptr((reinterpret_cast(v_cache_base)+cur_block_id*blk_stride));
+        b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ);
+        if (kv_pos == 0) {
+            // ugemm_PV0(slm_V, P, rO, slm_offset);
+            auto P2 = P.format();
+            #pragma unroll
+            for(int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) {
+                matrix Vmat;
+                cm_prefetch(prefetch_V.set_block_x(k));
+                cm_load(Vmat.format(), b2dV.set_block_x(k));
+                // sometimes the KV cache is filled with random NaN, so the unused value data needs to be cleaned up.
+                if ((kv_pos + kv_step) > kv_stop) {
+                    uint valid_rows = kv_stop - kv_pos;
+                    uint valid_rows_vnni = (valid_rows+1)/2;
+                    for (int r = valid_rows_vnni; r < kv_step / 2; r++)
+                        Vmat.row(r) = 0.f;
+                    if (valid_rows % 2 == 1)
+                        Vmat.row(valid_rows_vnni-1).select(1) = 0.f;
+                }
+                #pragma unroll
+                for(int p = 0; p < num_P_tiles; p++) {
+                    rO[ri + p] = cm_dpas(
+                        0,
+                        Vmat.format(),
+                        P2.row(p).format());
+                }
+            }
+        }
+        else {
+            //ugemm_PV1(slm_V, P, max_comp, rO, slm_offset);
+            auto P2 = P.format();
+            #pragma unroll
+            for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) {
+                matrix Vmat;
+
+                cm_prefetch(prefetch_V.set_block_x(k));
+                cm_load(Vmat.format(), b2dV.set_block_x(k));
+                // sometimes the KV cache is filled with random NaN, so the unused value data needs to be cleaned up.
+                if ((kv_pos + kv_step) > kv_stop) {
+                    uint valid_rows = kv_stop - kv_pos;
+                    uint valid_rows_vnni = (valid_rows+1)/2;
+                    for (int r = valid_rows_vnni; r < kv_step / 2; r++)
+                        Vmat.row(r) = 0.f;
+                    if (valid_rows % 2 == 1)
+                        Vmat.row(valid_rows_vnni-1).select(1) = 0.f;
+                }
+                //# compensate cur_O
+                // matrix rO;
+                #pragma unroll
+                for(int p = 0; p < num_P_tiles; p++) {
+                    auto cO = rO[ri + p].format();
+                    #pragma unroll
+                    for(int r = 0; r < REG_M; r++)
+                        cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]);
+                }
+
+                #pragma unroll
+                for(int p = 0; p < num_P_tiles; p++) {
+                    rO[ri + p] = cm_dpas(
+                        rO[ri + p].format(),
+                        Vmat.format(),
+                        P2.row(p).format());
+                }
+            }
+        }
+    }
+    if (q_tokens_left == 0) return;
+
+    //# save cur_O/cur_sum.transpose(0, 1)
+    matrix cur_O_f16;
+    cur_sum = cm_inv(cur_sum);
+
+    lsc::block_2d_desc b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0);
+
+    #pragma unroll
+    for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) {
+        #pragma unroll
+        for(int p = 0; p < num_P_tiles; p++) {
+            auto cO = rO[ri + p].format();
+            #pragma unroll
+            for(int r = 0; r < cO.n_rows(); r++) {
+                cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]);
+            }
+        }
+        b2dO.set_block_x(k);
+        cm_store(b2dO.set_block_y(0), cur_O_f16.format().row(0));
+        cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1));
+    }
+}
+
+#endif
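Both paged-attention variants above fold softmax into the K/V loop through the online-softmax recurrence implemented by online_softmax_update() in cm_attention_common.hpp. A scalar C++ reference sketch of one update step, illustrative only (the kernels compute exp through cm_exp((x) * log2e), i.e. in base 2, which is mathematically equivalent):

#include <algorithm>
#include <cmath>
#include <vector>

// One online-softmax step over a kv_step-wide chunk of scores `s` for a single
// query row. Updates the running max `m` and running denominator `l`, and returns
// the factor exp(m_old - m_new) that rescales the partial output accumulator.
float online_softmax_step(std::vector<float>& s, float& m, float& l) {
    float m_new = m;
    for (float v : s) m_new = std::max(m_new, v);                   // new running max
    float row_sum = 0.f;
    for (float& v : s) { v = std::exp(v - m_new); row_sum += v; }   // P = exp(S - m_new)
    float max_comp = std::exp(m - m_new);                           // compensation for older chunks
    l = l * max_comp + row_sum;                                     // rescale old denominator, add new
    m = m_new;
    return max_comp;  // caller applies: O = O * max_comp + P @ V
}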
diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp
index 74fe045cfff4a6..402cacb2e77674 100644
--- a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp
+++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp
@@ -13,315 +13,258 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
*******************************************************************************/ -#include -#include - -//# CM-compiler is C++17 -static_assert(__cplusplus >= 201703L); - -#define SystolicDepth 8 -#define RepeatCount 8 -#define VNNI_WIDTH 2 -#define REG_K (SystolicDepth * VNNI_WIDTH) -#define REG_M RepeatCount -//REG_N -// Xe1: 8 -// Xe2: 16 -#define REG_N (CM_GRF_WIDTH/32) - -#define kv_step REG_K -#define q_step REG_N - -constexpr float scale_factor = CMFLA_SCALE_FACTOR; - -static_assert(q_step == 16 || q_step == 8); -static_assert(kv_step == 16); -static_assert(CM_HAS_DPAS); - -template -void show(const matrix mat) { - printf("Matrix [%d, %d]:\n", M, N); - for(int m = 0; m < M; m ++) { - printf("\t["); - for(int n = 0; n < N; n ++) { - printf("%8.4f,", mat[m][n]); - } - printf("],\n"); - } - printf("]\n"); -} - -template -CM_INLINE void Transpose_16x16(matrix_ref in, - matrix_ref out) { - matrix bBuf; - bBuf.row(0) = in.template select<4, 1, 4, 4>(0, 0); // 0,4,8,c - bBuf.row(1) = in.template select<4, 1, 4, 4>(4, 0); // 0,4,8,c - bBuf.row(2) = in.template select<4, 1, 4, 4>(8, 0); // 0,4,8,c - bBuf.row(3) = in.template select<4, 1, 4, 4>(12, 0); // 0,4,8,c - bBuf.row(4) = in.template select<4, 1, 4, 4>(0, 1); // 1,5,9,d - bBuf.row(5) = in.template select<4, 1, 4, 4>(4, 1); // 1,5,9,d - bBuf.row(6) = in.template select<4, 1, 4, 4>(8, 1); // 1,5,9,d - bBuf.row(7) = in.template select<4, 1, 4, 4>(12, 1); // 1,5,9,d - bBuf.row(8) = in.template select<4, 1, 4, 4>(0, 2); // 2,6,a,e - bBuf.row(9) = in.template select<4, 1, 4, 4>(4, 2); // 2,6,a,e - bBuf.row(10) = in.template select<4, 1, 4, 4>(8, 2); // 2,6,a,e - bBuf.row(11) = in.template select<4, 1, 4, 4>(12, 2); // 2,6,a,e - bBuf.row(12) = in.template select<4, 1, 4, 4>(0, 3); // 3,7,b,f - bBuf.row(13) = in.template select<4, 1, 4, 4>(4, 3); // 3,7,b,f - bBuf.row(14) = in.template select<4, 1, 4, 4>(8, 3); // 3,7,b,f - bBuf.row(15) = in.template select<4, 1, 4, 4>(12, 3); // 3,7,b,f - - out.row(0) = bBuf.template select<4, 1, 4, 4>(0, 0); // 0 - out.row(1) = bBuf.template select<4, 1, 4, 4>(4, 0); // 1 - out.row(2) = bBuf.template select<4, 1, 4, 4>(8, 0); // 2 - out.row(3) = bBuf.template select<4, 1, 4, 4>(12, 0); // 3 - out.row(4) = bBuf.template select<4, 1, 4, 4>(0, 1); // 4 - out.row(5) = bBuf.template select<4, 1, 4, 4>(4, 1); // 5 - out.row(6) = bBuf.template select<4, 1, 4, 4>(8, 1); // 6 - out.row(7) = bBuf.template select<4, 1, 4, 4>(12, 1); // 7 - out.row(8) = bBuf.template select<4, 1, 4, 4>(0, 2); // 8 - out.row(9) = bBuf.template select<4, 1, 4, 4>(4, 2); // 9 - out.row(10) = bBuf.template select<4, 1, 4, 4>(8, 2); // a - out.row(11) = bBuf.template select<4, 1, 4, 4>(12, 2); // b - out.row(12) = bBuf.template select<4, 1, 4, 4>(0, 3); // c - out.row(13) = bBuf.template select<4, 1, 4, 4>(4, 3); // d - out.row(14) = bBuf.template select<4, 1, 4, 4>(8, 3); // e - out.row(15) = bBuf.template select<4, 1, 4, 4>(12, 3); // f -} - -template -CM_INLINE void Transpose_8x8(matrix_ref in, matrix_ref out) { - matrix temp; - temp.row(0) = in.template select<2, 1, 4, 2>(0, 0); - temp.row(1) = in.template select<2, 1, 4, 2>(2, 0); - temp.row(2) = in.template select<2, 1, 4, 2>(4, 0); - temp.row(3) = in.template select<2, 1, 4, 2>(6, 0); - temp.row(4) = in.template select<2, 1, 4, 2>(0, 1); - temp.row(5) = in.template select<2, 1, 4, 2>(2, 1); - temp.row(6) = in.template select<2, 1, 4, 2>(4, 1); - temp.row(7) = in.template select<2, 1, 4, 2>(6, 1); - - out.row(0) = temp.template select<4, 1, 2, 4>(0, 0); - out.row(2) = temp.template 
select<4, 1, 2, 4>(0, 1); - out.row(4) = temp.template select<4, 1, 2, 4>(0, 2); - out.row(6) = temp.template select<4, 1, 2, 4>(0, 3); - out.row(1) = temp.template select<4, 1, 2, 4>(4, 0); - out.row(3) = temp.template select<4, 1, 2, 4>(4, 1); - out.row(5) = temp.template select<4, 1, 2, 4>(4, 2); - out.row(7) = temp.template select<4, 1, 2, 4>(4, 3); -} - -// function templates cannot be partially specialized; use overloading to achieve the same effect -template -inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { - Transpose_8x8(in, out); -} -template -inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { - Transpose_16x16(in, out); -} -template -inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { - Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); - Transpose_8x8(in.select<8, 1, 8, 1>(8,0), out.select<8, 1, 8, 1>(0,8)); -} -template -inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { - Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); - Transpose_8x8(in.select<8, 1, 8, 1>(0,8), out.select<8, 1, 8, 1>(8,0)); -} - -template -CM_INLINE void slm_read_2d(matrix_ref out, uint slm, int offset) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - cm_slm_block_read(slm, GENX_DWALIGNED, offset + i*n_stride*sizeof(T), out.row(i)); - } -} - -template -CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - cm_svm_block_read(base + i * pitch, out[i]); - } -} - -template -CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - out.row(i).format() = cm_load(base, offset + i * pitch); - } -} - -template -CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - out.row(i).format() = cm_load(base, offset + i * pitch); - } -} - -template -CM_INLINE void cm_store_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - cm_store(base, offset + i * pitch, out.row(i).format()); - } -} - -template -CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, vector_ref offsets) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++) { - cm_svm_block_read(base + offsets[i], out[i]); - } -} - -template -CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++, base += pitch, n_rows--) { - if (n_rows > 0) cm_svm_block_read(base, out[i]); - } -} - -template -CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++, base += pitch) { - cm_svm_block_write(base, out[i]); - } -} - -template -CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { - #pragma unroll - for(int i = 0; i < out.n_rows(); i++, base += pitch) { - if (i < n_rows) cm_svm_block_write(base, out[i]); - } -} - -CM_INLINE uint64_t get_clock() { - auto clk = cm_clock(); - return ((uint64_t)clk[1]) << 32 | clk[0]; -} - - -template -inline matrix ugemm_KQ(uint slm_K, matrix_ref Qt, uint slm_offset = 0) { - matrix St; - constexpr int num_K = _kv_step/REG_M; - auto St2 = St.format(); - - matrix Kmat; - cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); - #pragma unroll - for(int k = 0; k < num_K; k++) - St2.row(k) = cm_dpas(0, 
Qt[0].format(), Kmat[k].format()); - - #pragma unroll - for(int ri = 1; ri < num_Qt; ri++) { - cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); - #pragma unroll - for(int k = 0; k < num_K; k++) { - St2.row(k) = cm_dpas(St2.row(k), Qt[ri].format(), Kmat[k].format()); - } - } - return St; -} - -template -inline void ugemm_PV0(uint slm_V, matrix_ref P, matrix_ref rO, uint slm_offset = 0) { - constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; - - auto P2 = P.format(); - #pragma unroll - for(int k = 0, ri = 0; k < _head_size; k += REG_N, ri += num_P_tiles) { - matrix Vmat; - cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); - - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - rO[ri + p] = cm_dpas( - 0, - Vmat.format(), - P2.row(p).format()); - //show(rO[ri + p].format()); - } - } -} - -template -inline void ugemm_PV1(uint slm_V, matrix_ref P, vector_ref max_comp, - matrix_ref rO, uint slm_offset = 0) { - constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; - auto P2 = P.format(); - #pragma unroll - for(int k = 0, ri=0; k < _head_size; k += REG_N, ri += num_P_tiles) { - matrix Vmat; - cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); - - //# compensate cur_O - // matrix rO; - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - auto cO = rO[ri + p].format(); - #pragma unroll - for(int r = 0; r < REG_M; r++) - cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); - } - - //show(rO[ri].format()); - - //# show(cur_O.format()); return; - #pragma unroll - for(int p = 0; p < num_P_tiles; p++) { - rO[ri + p] = cm_dpas( - rO[ri + p].format(), - Vmat.format(), - P2.row(p).format()); - //if (kv_pos == args_verbose) show(rO[ri + p].format()); - } - // if (kv_pos == args_verbose) show(cur_O.format()); - } -} - -template -vector online_softmax_update(matrix_ref St, vector_ref cur_max, vector_ref cur_sum) { - vector new_max_t; - new_max_t = cm_max(St[0], St[1]); - for(int r = 2; r < St.n_rows(); r++) new_max_t = cm_max(new_max_t, St[r]); - new_max_t = cm_max(new_max_t, cur_max); - - // Pt = torch.exp(St - new_max) - constexpr float log2e = 1.4426950408889634f; - for(int r = 0; r < St.n_rows(); r++) St[r] = cm_exp((St[r] - new_max_t)*log2e); - - vector row_sum_t; - row_sum_t = cm_add(St[0], St[1]); - for(int r = 2; r < St.n_rows(); r++) row_sum_t = cm_add(row_sum_t, St[r]); - - vector max_comp; - max_comp = cm_exp((cur_max - new_max_t)*log2e); - cur_sum = cm_mul(cur_sum, max_comp); - cur_sum = cm_add(cur_sum, row_sum_t); - cur_max = new_max_t; - return max_comp; -} - -//=============================================================================================== -template -constexpr void apply_causal_mask(matrix_ref St) { - if constexpr (i < N) { - St.row(i).select(0) = -3.4e38f; - apply_causal_mask(St); - } -} +#include "cm_attention_common.hpp" #ifdef CM_HAS_LSC_UNTYPED_2D +//@prefetch_u8 would have duplicated decompress perf issue. comments out for now. 
+// template +// void sdpa_kernel_lsc_prefetch_u8( +// int wg_local_id, +// int q_start, +// int kv_stop, // +// int q_len, //q_step +// int kv_len, //not used for now +// svmptr_t q_base [[type("svmptr_t")]], +// svmptr_t k_cache_base [[type("svmptr_t")]], +// svmptr_t v_cache_base [[type("svmptr_t")]], +// svmptr_t o_base [[type("svmptr_t")]], +// int32_t past_lens, +// int32_t* block_indices [[type("svmptr_t")]]) { +// constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); +// constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; +// //[block_num, kv_heads, block_size, head_size] +// constexpr uint kv_pitch = head_size * sizeof(uint8_t); + +// vector cur_max; +// vector cur_sum; + +// cur_max = -3e38f; +// cur_sum = 0; +// constexpr int num_P_tiles = REG_N / REG_M; +// matrix rQ; +// matrix rO; + +// auto q_tokens_left = q_len;// - q_start; +// static_assert(q_step == REG_N); +// static_assert(kv_step == REG_K); + +// if (q_tokens_left < 0) q_tokens_left = 0; +// if (q_tokens_left > q_step) q_tokens_left = q_step; + +// if (q_tokens_left > 0) { +// lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0); +// #pragma unroll +// for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { +// cm_load(rQ[ri].format(), b2dQ.set_block_x(k)); +// rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); +// } +// } + +// lsc::block_2d_desc b2dKV(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + +// static_assert(wg_local_size == 16); +// lsc::block_2d_desc b2dKV_prefetch(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); +// // constexpr int blk_stride = CMFLA_NUM_KV_HEADS * CMFLA_HEAD_SIZE * CMPA_BLOCK_SZ; +// constexpr int quan_blk_stride = CMFLA_NUM_KV_HEADS * (CMFLA_HEAD_SIZE+4) * CMPA_BLOCK_SZ * sizeof(uint8_t); + + + +// int causal_left = q_start+past_lens; +// for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step) { +// auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; +// //For the last step, duplicate prefetch here. +// uint32_t prefetch_kv_pos = (kv_pos+kv_step) >= kv_stop ? 
kv_pos : (kv_pos+kv_step); +// auto prefetch_block_id = block_indices[prefetch_kv_pos / CMPA_BLOCK_SZ]; +// uint32_t dscale_offset = cur_block_id*quan_blk_stride + CMPA_BLOCK_SZ * head_size * sizeof(uint8_t) + kv_pos%CMPA_BLOCK_SZ*sizeof(half); + +// vector dscale; +// vector zp; +// cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset), dscale); +// cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset + CMPA_BLOCK_SZ*sizeof(half)), zp); + +// // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { +// // show(dscale.format()); +// // } +// //# St = k @ Qt + +// matrix St; // = ugemm_KQ(slm_K, rQ, slm_offset); +// { +// constexpr int num_K = kv_step/REG_M; +// auto St2 = St.format(); +// matrix Kmat; +// auto quan_Kmat = Kmat.format().row(1).format(); +// auto dq_Kmat = Kmat.format(); +// //cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); + +// b2dKV_prefetch.set_base_ptr(reinterpret_cast(k_cache_base+prefetch_block_id*quan_blk_stride)); +// b2dKV_prefetch.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); +// cm_prefetch(b2dKV_prefetch.set_block_x(0)); + +// b2dKV.set_base_ptr(reinterpret_cast(k_cache_base+cur_block_id*quan_blk_stride)); +// b2dKV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + +// cm_load(quan_Kmat.format(), b2dKV.set_block_x(0)); +// // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { +// // show(quan_Kmat.format(), false); +// // } +// #pragma unroll +// for(int r = 0; r < kv_step; r++) { +// dq_Kmat[r] = quan_Kmat[r] - zp[r]; +// dq_Kmat[r] = cm_mul(dq_Kmat[r], dscale[r]); +// } + +// #pragma unroll +// for(int k = 0; k < num_K; k++) +// St2.row(k) = cm_dpas( +// 0, +// rQ[0].format(), +// Kmat[k].format()); + +// #pragma unroll +// for(int ri = 1; ri < head_size/REG_K; ri++) { +// cm_prefetch(b2dKV_prefetch.set_block_x(ri*REG_K)); +// //cm_load(Kmat.format(), b2dKV.set_block_x(ri*REG_K)); +// cm_load(quan_Kmat.format(), b2dKV.set_block_x(ri*REG_K)); +// #pragma unroll +// for(int r = 0; r < kv_step; r++) { +// dq_Kmat[r] = quan_Kmat[r] - zp[r]; +// dq_Kmat[r] = cm_mul(dq_Kmat[r], dscale[r]); +// } +// #pragma unroll +// for(int k = 0; k < num_K; k++) { +// St2.row(k) = cm_dpas( +// St2.row(k), +// rQ[ri].format(), +// Kmat[k].format()); +// } +// } +// } +// if constexpr (use_causal_mask) { +// // since kv_step == q_step == 16, causal_left is n*kv_step +// if (causal_left == 0) { +// apply_causal_mask<1>(St); +// } else if (causal_left < 0) { +// St = -3.4e38f; +// } +// causal_left -= kv_step; +// } else { +// int kv_tokens = kv_stop - kv_pos; +// // LSC ensures no overflow-access, but mask off k-tails attn-score is still required +// for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; +// } + +// //show(St); +// auto max_comp = online_softmax_update(St, cur_max, cur_sum); + +// matrix P; +// Transpose2DMatrix(St, P); + +// b2dKV_prefetch.set_base_ptr(reinterpret_cast(v_cache_base+prefetch_block_id*quan_blk_stride)); +// b2dKV_prefetch.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); + +// b2dKV.set_base_ptr(reinterpret_cast(v_cache_base+cur_block_id*quan_blk_stride)); +// b2dKV.set_block_y(kv_pos%CMPA_BLOCK_SZ); + + +// cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset), dscale); +// cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset+CMPA_BLOCK_SZ*sizeof(half)), zp); + +// { +// matrix VmatVNNI2; +// matrix Vmat; +// auto quanVmat = Vmat.format().row(1).format(); +// int kv_tokens = kv_stop - kv_pos; +// if (kv_pos == 0) { +// // ugemm_PV0(slm_V, P, rO, slm_offset); +// auto P2 = 
P.format(); +// #pragma unroll +// for(int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) { +// cm_prefetch(b2dKV_prefetch.set_block_x(k)); +// cm_load(quanVmat.format(), b2dKV.set_block_x(k)); +// #pragma unroll +// for(int r = 0; r < kv_step;r++) { +// Vmat[r] = quanVmat[r]-zp[r]; +// Vmat[r] = cm_mul(Vmat[r], dscale[r]); +// } +// for(int r = kv_step-1; r>=kv_tokens;r--) { +// Vmat[r] = 0; +// } + +// prepackAsVNNIWidth2(Vmat, VmatVNNI2); + +// #pragma unroll +// for(int p = 0; p < num_P_tiles; p++) { +// rO[ri + p] = cm_dpas( +// 0, +// VmatVNNI2.format(), +// P2.row(p).format()); +// } +// } +// } +// else { +// //ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); +// auto P2 = P.format(); +// #pragma unroll +// for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { +// cm_prefetch(b2dKV_prefetch.set_block_x(k)); +// cm_load(quanVmat.format(), b2dKV.set_block_x(k)); + +// #pragma unroll +// for(int r = 0; r < kv_step;r++) { +// Vmat[r] = quanVmat[r]-zp[r]; +// Vmat[r] = cm_mul(Vmat[r], dscale[r]); +// } +// for(int r = kv_step-1; r>=kv_tokens;r--) { +// Vmat[r] = 0; +// } + +// prepackAsVNNIWidth2(Vmat, VmatVNNI2); +// //# compensate cur_O +// // matrix rO; +// #pragma unroll +// for(int p = 0; p < num_P_tiles; p++) { +// auto cO = rO[ri + p].format(); +// #pragma unroll +// for(int r = 0; r < REG_M; r++) +// cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); +// } + +// #pragma unroll +// for(int p = 0; p < num_P_tiles; p++) { +// rO[ri + p] = cm_dpas( +// rO[ri + p].format(), +// VmatVNNI2.format(), +// P2.row(p).format()); +// } +// } +// } +// } +// } +// if (q_tokens_left == 0) return; + +// //# save cur_O/cur_sum.transpose(0, 1) +// matrix cur_O_f16; +// cur_sum = cm_inv(cur_sum); + +// lsc::block_2d_desc b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0); + +// #pragma unroll +// for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { +// #pragma unroll +// for(int p = 0; p < num_P_tiles; p++) { +// auto cO = rO[ri + p].format(); +// #pragma unroll +// for(int r = 0; r < cO.n_rows(); r++) { +// cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); + +// } +// } +// b2dO.set_block_x(k); +// cm_store(b2dO.set_block_y(0), cur_O_f16.format().row(0)); +// cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); +// } +// } + template void sdpa_kernel_lsc( uint slm_K, @@ -482,7 +425,6 @@ void sdpa_kernel_lsc( } } - template void sdpa_kernel_lsc_prefetch( int wg_local_id, @@ -662,7 +604,8 @@ void sdpa_kernel_lsc_prefetch( cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); } } -#endif + +#else // CM_HAS_LSC_UNTYPED_2D template void sdpa_kernel( @@ -865,3 +808,5 @@ void sdpa_kernel( } } } + +#endif // !CM_HAS_LSC_UNTYPED_2D \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp new file mode 100644 index 00000000000000..e34351db241d8d --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/estimate.hpp @@ -0,0 +1,1142 @@ +/* + * Copyright (c) 2020-2023, Intel Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is 
furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +#include +#include + +#if defined(SHIM) || defined(CMRT_EMU) +#define ATTR +#define ATTR_BUF +#define CM_LOCAL_BARRIER 0x20 +#include "emu/block2d.h" +#else +#define ATTR [[type("svmptr_t")]] +#define ATTR_BUF [[type("buffer_t")]] +#endif + +#define MYMIN(x, y) ((x) < (y) ? (x) : (y)) + +template +void show(const vector mat) { + printf("vector [%d]:\n[", N); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[n]); + } + printf("]\n"); +} + +template +void show_i(const vector mat) { + printf("vector [%d]:\n[", N); + for(int n = 0; n < N; n ++) { + printf("%d,", mat[n]); + } + printf("]\n"); +} + +template +void show(const matrix mat) { + printf("Matrix [%d, %d]:\n", M, N); + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[m][n]); + } + printf("],\n"); + } + printf("]\n"); +} + +template +void show(const matrix_ref mat) { + printf("Matrix [%d, %d]:\n", M, N); + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[m][n]); + } + printf("],\n"); + } + printf("]\n"); +} + +template +CM_INLINE void Transpose_8x8(matrix_ref in, matrix_ref out) { + matrix temp; + temp.row(0) = in.template select<2, 1, 4, 2>(0, 0); + temp.row(1) = in.template select<2, 1, 4, 2>(2, 0); + temp.row(2) = in.template select<2, 1, 4, 2>(4, 0); + temp.row(3) = in.template select<2, 1, 4, 2>(6, 0); + temp.row(4) = in.template select<2, 1, 4, 2>(0, 1); + temp.row(5) = in.template select<2, 1, 4, 2>(2, 1); + temp.row(6) = in.template select<2, 1, 4, 2>(4, 1); + temp.row(7) = in.template select<2, 1, 4, 2>(6, 1); + + out.row(0) = temp.template select<4, 1, 2, 4>(0, 0); + out.row(2) = temp.template select<4, 1, 2, 4>(0, 1); + out.row(4) = temp.template select<4, 1, 2, 4>(0, 2); + out.row(6) = temp.template select<4, 1, 2, 4>(0, 3); + out.row(1) = temp.template select<4, 1, 2, 4>(4, 0); + out.row(3) = temp.template select<4, 1, 2, 4>(4, 1); + out.row(5) = temp.template select<4, 1, 2, 4>(4, 2); + out.row(7) = temp.template select<4, 1, 2, 4>(4, 3); +} + +template +CM_INLINE void Transpose_8x32(matrix_ref in, matrix_ref out) { + Transpose_8x8(in.template select<8, 1, 8, 1>(0, 0), out.template select<8, 1, 8, 1>( 0, 0)); + Transpose_8x8(in.template select<8, 1, 8, 1>(0, 8), out.template select<8, 1, 8, 1>( 8, 0)); + Transpose_8x8(in.template select<8, 1, 8, 1>(0, 16), out.template select<8, 1, 8, 1>(16, 0)); + Transpose_8x8(in.template select<8, 1, 8, 1>(0, 24), out.template select<8, 1, 8, 1>(24, 0)); +} + +template +CM_INLINE void Transpose_4x32(matrix_ref in, matrix_ref out) { + matrix temp; + temp.row(0) = in.template select<4, 1, 8, 4>(0, 0); + temp.row(1) = in.template select<4, 1, 8, 4>(0, 1); + temp.row(2) = in.template select<4, 1, 8, 4>(0, 2); + temp.row(3) = in.template select<4, 1, 8, 4>(0, 3); + + out.row( 0) = temp.template 
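// Reference sketch: Transpose_8x8 above performs the transpose in two register-shuffle
// passes (stride-2 gathers into temp, then stride-4 gathers into out), i.e. 16 regioned
// register moves instead of 64 scalar ones. Functionally it matches this plain C++
// reference (float is a stand-in element type; the kernel shuffles half registers):
//
//   #include <array>
//   using Mat8x8 = std::array<std::array<float, 8>, 8>;
//   Mat8x8 transpose8x8(const Mat8x8& in) {
//       Mat8x8 out{};
//       for (int r = 0; r < 8; ++r)
//           for (int c = 0; c < 8; ++c)
//               out[c][r] = in[r][c];              // out = in^T
//       return out;
//   }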
select<1, 1, 4, 8>(0, 0); + out.row( 1) = temp.template select<1, 1, 4, 8>(1, 0); + out.row( 2) = temp.template select<1, 1, 4, 8>(2, 0); + out.row( 3) = temp.template select<1, 1, 4, 8>(3, 0); + out.row( 4) = temp.template select<1, 1, 4, 8>(0, 1); + out.row( 5) = temp.template select<1, 1, 4, 8>(1, 1); + out.row( 6) = temp.template select<1, 1, 4, 8>(2, 1); + out.row( 7) = temp.template select<1, 1, 4, 8>(3, 1); + out.row( 8) = temp.template select<1, 1, 4, 8>(0, 2); + out.row( 9) = temp.template select<1, 1, 4, 8>(1, 2); + out.row(10) = temp.template select<1, 1, 4, 8>(2, 2); + out.row(11) = temp.template select<1, 1, 4, 8>(3, 2); + out.row(12) = temp.template select<1, 1, 4, 8>(0, 3); + out.row(13) = temp.template select<1, 1, 4, 8>(1, 3); + out.row(14) = temp.template select<1, 1, 4, 8>(2, 3); + out.row(15) = temp.template select<1, 1, 4, 8>(3, 3); + out.row(16) = temp.template select<1, 1, 4, 8>(0, 4); + out.row(17) = temp.template select<1, 1, 4, 8>(1, 4); + out.row(18) = temp.template select<1, 1, 4, 8>(2, 4); + out.row(19) = temp.template select<1, 1, 4, 8>(3, 4); + out.row(20) = temp.template select<1, 1, 4, 8>(0, 5); + out.row(21) = temp.template select<1, 1, 4, 8>(1, 5); + out.row(22) = temp.template select<1, 1, 4, 8>(2, 5); + out.row(23) = temp.template select<1, 1, 4, 8>(3, 5); + out.row(24) = temp.template select<1, 1, 4, 8>(0, 6); + out.row(25) = temp.template select<1, 1, 4, 8>(1, 6); + out.row(26) = temp.template select<1, 1, 4, 8>(2, 6); + out.row(27) = temp.template select<1, 1, 4, 8>(3, 6); + out.row(28) = temp.template select<1, 1, 4, 8>(0, 7); + out.row(29) = temp.template select<1, 1, 4, 8>(1, 7); + out.row(30) = temp.template select<1, 1, 4, 8>(2, 7); + out.row(31) = temp.template select<1, 1, 4, 8>(3, 7); +} + +template +CM_INLINE void Transpose_32x32(matrix_ref in, matrix_ref out) { + matrix temp; + temp.row( 0) = in.template select<8, 1, 4, 8>( 0, 0); + temp.row( 1) = in.template select<8, 1, 4, 8>( 8, 0); + temp.row( 2) = in.template select<8, 1, 4, 8>(16, 0); + temp.row( 3) = in.template select<8, 1, 4, 8>(24, 0); + temp.row( 4) = in.template select<8, 1, 4, 8>( 0, 1); + temp.row( 5) = in.template select<8, 1, 4, 8>( 8, 1); + temp.row( 6) = in.template select<8, 1, 4, 8>(16, 1); + temp.row( 7) = in.template select<8, 1, 4, 8>(24, 1); + temp.row( 8) = in.template select<8, 1, 4, 8>( 0, 2); + temp.row( 9) = in.template select<8, 1, 4, 8>( 8, 2); + temp.row(10) = in.template select<8, 1, 4, 8>(16, 2); + temp.row(11) = in.template select<8, 1, 4, 8>(24, 2); + temp.row(12) = in.template select<8, 1, 4, 8>( 0, 3); + temp.row(13) = in.template select<8, 1, 4, 8>( 8, 3); + temp.row(14) = in.template select<8, 1, 4, 8>(16, 3); + temp.row(15) = in.template select<8, 1, 4, 8>(24, 3); + temp.row(16) = in.template select<8, 1, 4, 8>( 0, 4); + temp.row(17) = in.template select<8, 1, 4, 8>( 8, 4); + temp.row(18) = in.template select<8, 1, 4, 8>(16, 4); + temp.row(19) = in.template select<8, 1, 4, 8>(24, 4); + temp.row(20) = in.template select<8, 1, 4, 8>( 0, 5); + temp.row(21) = in.template select<8, 1, 4, 8>( 8, 5); + temp.row(22) = in.template select<8, 1, 4, 8>(16, 5); + temp.row(23) = in.template select<8, 1, 4, 8>(24, 5); + temp.row(24) = in.template select<8, 1, 4, 8>( 0, 6); + temp.row(25) = in.template select<8, 1, 4, 8>( 8, 6); + temp.row(26) = in.template select<8, 1, 4, 8>(16, 6); + temp.row(27) = in.template select<8, 1, 4, 8>(24, 6); + temp.row(28) = in.template select<8, 1, 4, 8>( 0, 7); + temp.row(29) = in.template select<8, 1, 4, 8>( 8, 7); + temp.row(30) = 
in.template select<8, 1, 4, 8>(16, 7); + temp.row(31) = in.template select<8, 1, 4, 8>(24, 7); + + out.row( 0) = temp.template select<4, 1, 8, 4>( 0, 0); + out.row( 1) = temp.template select<4, 1, 8, 4>( 4, 0); + out.row( 2) = temp.template select<4, 1, 8, 4>( 8, 0); + out.row( 3) = temp.template select<4, 1, 8, 4>(12, 0); + out.row( 4) = temp.template select<4, 1, 8, 4>(16, 0); + out.row( 5) = temp.template select<4, 1, 8, 4>(20, 0); + out.row( 6) = temp.template select<4, 1, 8, 4>(24, 0); + out.row( 7) = temp.template select<4, 1, 8, 4>(28, 0); + out.row( 8) = temp.template select<4, 1, 8, 4>( 0, 1); + out.row( 9) = temp.template select<4, 1, 8, 4>( 4, 1); + out.row(10) = temp.template select<4, 1, 8, 4>( 8, 1); + out.row(11) = temp.template select<4, 1, 8, 4>(12, 1); + out.row(12) = temp.template select<4, 1, 8, 4>(16, 1); + out.row(13) = temp.template select<4, 1, 8, 4>(20, 1); + out.row(14) = temp.template select<4, 1, 8, 4>(24, 1); + out.row(15) = temp.template select<4, 1, 8, 4>(28, 1); + out.row(16) = temp.template select<4, 1, 8, 4>( 0, 2); + out.row(17) = temp.template select<4, 1, 8, 4>( 4, 2); + out.row(18) = temp.template select<4, 1, 8, 4>( 8, 2); + out.row(19) = temp.template select<4, 1, 8, 4>(12, 2); + out.row(20) = temp.template select<4, 1, 8, 4>(16, 2); + out.row(21) = temp.template select<4, 1, 8, 4>(20, 2); + out.row(22) = temp.template select<4, 1, 8, 4>(24, 2); + out.row(23) = temp.template select<4, 1, 8, 4>(28, 2); + out.row(24) = temp.template select<4, 1, 8, 4>( 0, 3); + out.row(25) = temp.template select<4, 1, 8, 4>( 4, 3); + out.row(26) = temp.template select<4, 1, 8, 4>( 8, 3); + out.row(27) = temp.template select<4, 1, 8, 4>(12, 3); + out.row(28) = temp.template select<4, 1, 8, 4>(16, 3); + out.row(29) = temp.template select<4, 1, 8, 4>(20, 3); + out.row(30) = temp.template select<4, 1, 8, 4>(24, 3); + out.row(31) = temp.template select<4, 1, 8, 4>(28, 3); +} + +// group_count: M = group_count * group_size, group_size is the element count of current reduction +// op: 0-max, 1-sum +// M: before reduce element count, N: row count, stop: element count M must be larger than +template +CM_INLINE constexpr auto reduce2d(matrix_ref src) { + constexpr int group_size = M / group_count; + if constexpr (N > stop) { + matrix result; + // half of group will be reduced + constexpr int new_group_size = group_size / 2; + constexpr int new_group_count = group_count * 2; +#pragma unroll + for (int i = 0; i < N / 2; i++) { + matrix new_top, new_bot; + auto top = src.row(2 * i + 0).format(); + auto bot = src.row(2 * i + 1).format(); + constexpr int v_stride = new_group_count == 2 ? 
1 : 2; + + new_top.select(0) = top.select(0); + new_top.select(new_group_count / 2) = bot.select(0); + new_bot.select(0) = top.select(1); + new_bot.select(new_group_count / 2) = bot.select(1); + if constexpr (op == 0) { + result[i] = cm_max(new_top.format(), new_bot.format()); + } else { + result[i] = (new_top.format() + new_bot.format()); + } + } + + return reduce2d(result); + } else { + matrix dst = src; + return dst; + } +} + +template +CM_INLINE void read_1d(vector_ref out, svmptr_t base) { + cm_ptr_block_read((T*)base, out); +} + +template +CM_INLINE void read_2d(matrix_ref out, svmptr_t base, uint pitch) { +#pragma unroll + for (int i = 0; i < out.n_rows(); i++, base += pitch) { + cm_ptr_block_read((T*)base, out.row(i)); + } +} + +template +CM_INLINE void write_2d(matrix_ref out, svmptr_t base, uint pitch) { +#pragma unroll + for (int i = 0; i < out.n_rows(); i++, base += pitch) { + cm_ptr_block_write((TSRC*)base, out.row(i)); + } +} + +template +CM_INLINE void write_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { +#pragma unroll + for (int i = 0; i < out.n_rows(); i++, offset += pitch) { + cm_store(base, offset, out.row(i).format()); + } +} + + +#if USE_KQ == 1 && (BLOCK_SG_M == 64 && BLOCK_SG_N == 32) +// register tile: [8, 2] aka[(8*8,16), (16, 16*2)] +// src_a is key, src_b is query +CM_INLINE void gemm_kq_64x32_xe2(uint id_wg_m, uint id_wg_n, uint hq, uint slm, svmptr_t key_cache, svmptr_t query, svmptr_t block_indices ATTR, svmptr_t block_indices_begins ATTR, svmptr_t kq_max ATTR, svmptr_t kq_max_wg ATTR, svmptr_t kq_exp_partial_sum ATTR, +uint M, uint N, uint K, uint query_stride, uint q_start_strided) { + constexpr int SG_SIZE = 16; + constexpr int BLOCK_WG_K = 64; // same in sg +#ifndef BLOCK_SG_M + #define BLOCK_SG_M 64 + #define BLOCK_SG_N 32 + #define SG_M 4 + #define SG_N 4 + #define HEAD_SIZE 128 + #define KV_BLOCK_SIZE 256 + #define STRIDE 16 +#endif + // xehpg DPAS spec: dst: [8, 8], repeat: 1~8, depth: 8 + static constexpr int REPEAT = 8; + static constexpr int DEPTH = 8; + static constexpr int BLOCK_REG_M = REPEAT; + static constexpr int BLOCK_REG_N = SG_SIZE; + static constexpr int BLOCK_DPAS_C = BLOCK_REG_M * BLOCK_REG_N; + static constexpr int VNNI = sizeof(half); + static constexpr int BLOCK_REG_K = DEPTH * sizeof(int) / VNNI; + static constexpr int BLOCK_REG_A = BLOCK_REG_M * BLOCK_REG_K; + static constexpr int BLOCK_REG_B = BLOCK_REG_N * BLOCK_REG_K; + static constexpr int BLOCK_WG_M = SG_M * BLOCK_SG_M; + static constexpr int BLOCK_WG_N = SG_N * BLOCK_SG_N; + // register blocking + static constexpr int REG_M = BLOCK_SG_M / BLOCK_REG_M; + static constexpr int REG_N = BLOCK_SG_N / BLOCK_REG_N; + static constexpr int REG_K = BLOCK_WG_K / BLOCK_REG_K; + static constexpr int REG_MN = REG_M * REG_N; + static constexpr int KEY_LINES_PER_LOAD = KV_BLOCK_SIZE / STRIDE; + + matrix acc = 0; // --> 64*2 regs + uint id_sg_n = cm_local_id(0); + uint id_sg_m = cm_local_id(1); + uint id_sg_mn = id_sg_m * SG_N + id_sg_n; + + static_assert(REG_N == 2, "block_2d_desc for b is manually unrolled by 2"); + static_assert(HEAD_SIZE % BLOCK_WG_K == 0, "K dimension must be multiple of BLOCK_WG_K"); + static_assert(KV_BLOCK_SIZE == 256, "block size of key(key_cache) should be 256"); + uint M_block = (M + BLOCK_WG_M - 1) / BLOCK_WG_M; + uint N_aligned = (N + BLOCK_WG_N - 1) / BLOCK_WG_N * BLOCK_WG_N; + uint M_block_aligned = M_block * (BLOCK_WG_M / (BLOCK_SIZE / STRIDE)); + const uint block_size_div_stride = BLOCK_SIZE / STRIDE; + constexpr half log2e = 
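// Worked example: with the defaults above (REPEAT = 8, DEPTH = 8, SG_SIZE = 16,
// BLOCK_SG_M = 64, BLOCK_SG_N = 32, SG_M = SG_N = 4, BLOCK_WG_K = 64) the blocking
// resolves to BLOCK_REG_K = 8 * sizeof(int) / sizeof(half) = 16 K-elements per DPAS,
// REG_M = 64 / 8 = 8 and REG_N = 32 / 16 = 2 accumulator tiles per subgroup (16 tiles
// of 8x16 floats), and a 256 x 128 workgroup tile. A standalone constexpr check:
//
//   constexpr int REPEAT = 8, DEPTH = 8, SG_SIZE = 16;
//   constexpr int BLOCK_REG_K = DEPTH * 4 / 2;     // sizeof(int) / sizeof(half)
//   static_assert(BLOCK_REG_K == 16, "16 half K-elements per dpas");
//   static_assert(64 / REPEAT == 8 && 32 / SG_SIZE == 2, "8 x 2 accumulator tiles");
//   static_assert(4 * 64 == 256 && 4 * 32 == 128, "workgroup tile is 256 x 128");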
1.4426950408889634f; + static_assert(BLOCK_SG_M / block_size_div_stride == 8, "BLOCK_SG_M / block_size_div_stride should be 8"); + static_assert(BLOCK_SG_N == 32, "BLOCK_SG_N should be 32"); + +#if IS_CAUSAL == 1 + if ((int)(id_wg_m * BLOCK_WG_M) >= ((int)id_wg_n + 1) * BLOCK_WG_N + q_start_strided) { + // fill -inf -> max in group, 0 -> exp_sum to make compensation work + { + // current max -> mem + vector max_n = -60000; + // kq_max_wg: [b, hq, M/BLOCK_WG_M, N_aligned] + uint offset = (id_wg_m * N_aligned + id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * sizeof(half); + cm_ptr_store((int*)kq_max_wg, offset, max_n.format()); + } + { + // store + matrix sum_t = 0; + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, N - 1, (uint)(M_block_aligned * sizeof(half) - 1), (uint)(M_block_aligned * sizeof(half) - 1), + (int)((id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) / block_size_div_stride), (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) }; + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + } + + return; + } +#endif + // assume block index coming from 0 in block_indices_begins + int block_index_begin = ((int*)block_indices_begins)[0]; + int* block_indices_p = (int*)block_indices + block_index_begin; + int b_adjacent_between_head = query_stride / STRIDE; + // N[0:16*2]xK[0:16] + lsc::block_2d_desc desc_b0{ query, N - 1, (uint)((query_stride - hq * HEAD_SIZE) * sizeof(half) - 1), (uint)(query_stride * sizeof(half) - 1), + (STRIDE - 1) * b_adjacent_between_head / 2, (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) }; + // prefetch B + static constexpr int SG_MN = SG_M * SG_N; + lsc::block_2d_desc desc_prefetch_b{query, N - 1, (uint)((query_stride - hq * HEAD_SIZE) * sizeof(half) - 1), (uint)(query_stride * sizeof(half) - 1), + (STRIDE - 1) * b_adjacent_between_head, (int)(id_wg_n * BLOCK_WG_N + id_sg_mn * (BLOCK_WG_N / SG_MN)) }; + // N[0:16*2]xK[0:16] --> 8+8 regs + matrix b0, b1; + + // M[0:16]xK[0:32] + uint block_idx = (uint)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * STRIDE / KV_BLOCK_SIZE; + uint offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_a0{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // M[16:32]xK[0:32] + offset = block_indices_p[block_idx + 1] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_a1{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // M[32:48]xK[0:32] + offset = block_indices_p[block_idx + 2] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_a2{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // M[48:64]xK[0:32] + offset = block_indices_p[block_idx + 3] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_a3{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // prefetch A + block_idx = (uint)(id_wg_m * BLOCK_WG_M + id_sg_mn * (BLOCK_WG_M / SG_MN)) * STRIDE / KV_BLOCK_SIZE; + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + static_assert(BLOCK_WG_M / SG_MN <= KEY_LINES_PER_LOAD, "prefetch lines should be inside one block"); + lsc::block_2d_desc desc_prefetch_a{ key_cache + offset, BLOCK_WG_M / 
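// Reference sketch: desc_a0..desc_a3 above resolve logical KV rows to physical
// addresses in the paged cache: each 256-token block lives wherever block_indices
// points. A host-side sketch of the same address math, assuming the
// [block, HK, KV_BLOCK_SIZE, HEAD_SIZE] half-element layout implied by the offsets:
//
//   #include <cstdint>
//   std::uint64_t kv_block_base(std::uint64_t cache_base, const int* block_indices,
//                               unsigned token_row, unsigned stride,       // STRIDE
//                               unsigned hk, unsigned head_size) {
//       const unsigned kv_block_size = 256;                 // KV_BLOCK_SIZE
//       unsigned block_idx = token_row * stride / kv_block_size;
//       std::uint64_t bytes_per_block =
//           std::uint64_t(hk) * kv_block_size * head_size * 2;   // 2 = sizeof(half)
//       return cache_base + block_indices[block_idx] * bytes_per_block;
//   }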
SG_MN - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // 0~2 M[:]xK[0:16] 2~4 K[16:32] --> 32 * 2 regs + matrix a0, a1, a2, a3; + + // warmup + // prefetch + cm_prefetch(desc_prefetch_b); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + cm_prefetch(desc_prefetch_a); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load b: N[0:16]xK[0:16] + cm_load(b0.row(0).format(), desc_b0); + cm_load(b0.row(1).format(), desc_b0); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + cm_sbarrier(1); + + auto dot = [&](matrix_ref A0, matrix_ref A1, matrix_ref A2, matrix_ref A3, matrix_ref B) { +#pragma unroll + for (int reg_n = 0; reg_n < REG_N; reg_n++) { +#pragma unroll + for (uint reg_m = 0; reg_m < 2; reg_m++) { + acc.row((ushort)(reg_m * REG_N + reg_n)) = cm_dpas(acc.row((ushort)(reg_m * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A0.row((ushort)reg_m).format()); + } +#pragma unroll + for (uint reg_m = 0; reg_m < 2; reg_m++) { + acc.row((ushort)((reg_m + 2) * REG_N + reg_n)) = cm_dpas(acc.row((ushort)((reg_m + 2) * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A1.row((ushort)reg_m).format()); + } +#pragma unroll + for (uint reg_m = 0; reg_m < 2; reg_m++) { + acc.row((ushort)((reg_m + 4) * REG_N + reg_n)) = cm_dpas(acc.row((ushort)((reg_m + 4) * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A2.row((ushort)reg_m).format()); + } +#pragma unroll + for (uint reg_m = 0; reg_m < 2; reg_m++) { + acc.row((ushort)((reg_m + 6) * REG_N + reg_n)) = cm_dpas(acc.row((ushort)((reg_m + 6) * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A3.row((ushort)reg_m).format()); + } + } + }; + + for (uint s = 0; s < STRIDE; s++) { + #pragma unroll + for (uint hs = 0; hs < HEAD_SIZE / BLOCK_WG_K; hs++) { + // prefetch + cm_prefetch(desc_prefetch_b); + cm_prefetch(desc_prefetch_a); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + if (hs == HEAD_SIZE / BLOCK_WG_K - 1) + desc_prefetch_b.set_block_x((STRIDE - 1 - s - 1) * b_adjacent_between_head); + else + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + + // load b: N[0:16*2]xK[16:32] + cm_load(b1.row(0).format(), desc_b0); + cm_load(b1.row(1).format(), desc_b0); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + + // load a: M[0:16*4]xK[0:32] + cm_load(a0.format(), desc_a0); + cm_load(a1.format(), desc_a1); + cm_load(a2.format(), desc_a2); + cm_load(a3.format(), desc_a3); + + desc_a0.set_block_x(desc_a0.get_block_x() + 32); + desc_a1.set_block_x(desc_a1.get_block_x() + 32); + desc_a2.set_block_x(desc_a2.get_block_x() + 32); + desc_a3.set_block_x(desc_a3.get_block_x() + 32); + + dot(a0.select<2, 1, BLOCK_REG_A, 1>(), a1.select<2, 1, BLOCK_REG_A, 1>(), + a2.select<2, 1, BLOCK_REG_A, 1>(), a3.select<2, 1, BLOCK_REG_A, 1>(), + b0); + + // load b: N[0:16*2]xK[32:48] + cm_load(b0.row(0).format(), desc_b0); + cm_load(b0.row(1).format(), desc_b0); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + + dot(a0.select<2, 1, BLOCK_REG_A, 1>(2), a1.select<2, 1, BLOCK_REG_A, 1>(2), + a2.select<2, 1, BLOCK_REG_A, 1>(2), a3.select<2, 1, BLOCK_REG_A, 1>(2), + b1); + + // prefetch + cm_prefetch(desc_prefetch_b); + cm_prefetch(desc_prefetch_a); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load b: N[0:16*2]xK[48:64] + cm_load(b1.row(0).format(), desc_b0); + cm_load(b1.row(1).format(), desc_b0); + if (hs == HEAD_SIZE / BLOCK_WG_K - 1) + desc_b0.set_block_x((STRIDE - 1 - s - 1) * 
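// Reference sketch: each cm_dpas in `dot` above accumulates one BLOCK_REG_M x
// BLOCK_REG_N (8 x 16) float tile from a 16-deep half K-slice; the manual unrolling
// over a0..a3 only keeps four independent accumulator chains in flight. A scalar C++
// reference of a single such multiply-accumulate (float stand-ins for the half inputs):
//
//   void dpas_ref(float C[8][16], const float A[8][16], const float B[16][16]) {
//       for (int m = 0; m < 8; ++m)
//           for (int n = 0; n < 16; ++n) {
//               float acc = C[m][n];
//               for (int k = 0; k < 16; ++k) acc += A[m][k] * B[k][n];   // C += A*B
//               C[m][n] = acc;
//           }
//   }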
b_adjacent_between_head / 2); + else + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + // load a: M[0:32]xK[32:64] + cm_load(a0.format(), desc_a0); + cm_load(a1.format(), desc_a1); + cm_load(a2.format(), desc_a2); + cm_load(a3.format(), desc_a3); + desc_a0.set_block_x(desc_a0.get_block_x() + 32); + desc_a1.set_block_x(desc_a1.get_block_x() + 32); + desc_a2.set_block_x(desc_a2.get_block_x() + 32); + desc_a3.set_block_x(desc_a3.get_block_x() + 32); + + dot(a0.select<2, 1, BLOCK_REG_A, 1>(), a1.select<2, 1, BLOCK_REG_A, 1>(), + a2.select<2, 1, BLOCK_REG_A, 1>(), a3.select<2, 1, BLOCK_REG_A, 1>(), + b0); + + // load b: N[0:16*4]xK[0:16] + cm_load(b0.row(0).format(), desc_b0); + cm_load(b0.row(1).format(), desc_b0); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + + dot(a0.select<2, 1, BLOCK_REG_A, 1>(2), a1.select<2, 1, BLOCK_REG_A, 1>(2), + a2.select<2, 1, BLOCK_REG_A, 1>(2), a3.select<2, 1, BLOCK_REG_A, 1>(2), + b1); + cm_sbarrier(0); + cm_sbarrier(1); + } + } + + cm_sbarrier(0); + + matrix acc_half; + union { + float f; + int y; + } i2f; + i2f.y = INV_S; + const float inv_s = i2f.f; +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M; reg_m++) { +#pragma unroll + for (int reg_n = 0; reg_n < REG_N; reg_n++) { + acc_half.select(reg_m * BLOCK_REG_M, reg_n * BLOCK_REG_N) = + acc.row(reg_m * REG_N + reg_n) * inv_s; + } + } + + // if N(aka query) has tails, the following will not change the accuracy: + // gemm will compute results for the padding N(all should be zeros), the kq_max/kq_max_wg/kq_exp_partial_sum are along the query dimension and + // the results can be dropped in the future stage. To simplify the logic, the size of kq_max/kq_max_wg/kq_exp_partial_sum must be enough to hold + // all tails + padding results. + int m_start = (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M); + m_start = MYMIN(m_start, M); + int m_end = MYMIN(m_start + BLOCK_SG_M, M); + int valid_m = m_end - m_start; + matrix sum_t; + vector seq; + cmtl::cm_vector_assign(seq.select_all(), 0, 1); +#if IS_CAUSAL == 1 + bool skip_mask = false; + // in streaming scenario, the past kvcache length may be arbitrary so valid causal mask of a workgroup may start at arbitrary position + if (m_end <= (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + q_start_strided)) { + // all are inside causal mask == 1 + skip_mask = true; + } else { + vector n_pos = (uint)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + q_start_strided) + seq; + #pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + SIMD_IF_BEGIN (m_start + reg_m > n_pos) { + acc_half.row(reg_m) = half{-60000}; + } SIMD_IF_END; + } + } +#else + bool skip_mask = true; +#endif + // case for valid_m == BLOCK_SG_M but skip_mask == false which needs to handle causal mask: + // query = 128 * 2 + 1, key = 256 * 2 + if (valid_m == BLOCK_SG_M && skip_mask) { + vector max_n = acc_half.row(0); + #pragma unroll + for (uint reg_m = 1; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + max_n = cm_max(max_n, acc_half.row(reg_m)); + } + + { + uint slm_offset = (id_sg_m * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * (uint)sizeof(half); + // current max -> slm + cm_slm_block_write(slm, slm_offset, max_n.format()); + cm_slm_fence(CM_LOCAL_BARRIER); + cm_barrier(); + // max inside wg + cm_slm_block_read(slm, id_sg_n * BLOCK_SG_N * (uint)sizeof(half), max_n.format()); + vector tmp; + #pragma unroll + for (uint i = 1; i < SG_M; i++) { + slm_offset = (i * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * (uint)sizeof(half); + cm_slm_block_read(slm, slm_offset, tmp.format()); + max_n = cm_max(max_n, 
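// Reference sketch: INV_S above is the softmax scale delivered as the raw bit pattern
// of a float through a compile-time define, and the union reinterprets the bits;
// passing bits avoids any host/kernel float-formatting drift. The C++20 equivalent
// of the round-trip:
//
//   #include <bit>
//   #include <cstdint>
//   float        decode_scale(std::int32_t bits) { return std::bit_cast<float>(bits); }
//   std::int32_t encode_scale(float s) { return std::bit_cast<std::int32_t>(s); }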
tmp); + } + // max across wg + // kq_max: [b, hq, N_aligned] + vector max_offsets = (id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + seq) * (uint)sizeof(half); + cm_ptr_atomic((half*)kq_max, max_offsets, max_n); + + // current max -> mem + // kq_max_wg: [b, hq, M/BLOCK_WG_M, N_aligned] + uint offset = (id_wg_m * N_aligned + id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * sizeof(half); + cm_ptr_store((int*)kq_max_wg, offset, max_n.format()); + } + { + // kq_exp_partial_sum: [b, hq, N_aligned, M/(BLOCK_SIZE/STRIDE)] + matrix sum; + #pragma unroll + for (uint m = 0; m < BLOCK_SG_M / block_size_div_stride; m++) { + sum.row(m) = cm_exp((acc_half.row(m * block_size_div_stride) - max_n) * log2e); + #pragma unroll + for (uint sub_m = 1; sub_m < block_size_div_stride; sub_m++) { + uint real_m = m * block_size_div_stride + sub_m; + sum.row(m) += cm_exp((acc_half.row(real_m) - max_n) * log2e); + } + } + + Transpose_8x32(sum, sum_t); + } + } else { + // M tails + vector max_n = -60000; + for (uint reg_m = 0; reg_m < valid_m; reg_m++) { + max_n = cm_max(max_n, acc_half.row(reg_m)); + } + + { + uint slm_offset = (id_sg_m * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * (uint)sizeof(half); + // current max -> slm + cm_slm_block_write(slm, slm_offset, max_n.format()); + cm_slm_fence(CM_LOCAL_BARRIER); + cm_barrier(); + // max inside wg + cm_slm_block_read(slm, id_sg_n * BLOCK_SG_N * (uint)sizeof(half), max_n.format()); + vector tmp; + #pragma unroll + for (uint i = 1; i < SG_M; i++) { + slm_offset = (i * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * (uint)sizeof(half); + cm_slm_block_read(slm, slm_offset, tmp.format()); + max_n = cm_max(max_n, tmp); + } + // max across wg + // kq_max: [b, hq, N_aligned] + vector max_offsets = (id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + seq) * (uint)sizeof(half); + cm_ptr_atomic((half*)kq_max, max_offsets, max_n); + + // current max -> mem + // kq_max_wg: [b, hq, M/BLOCK_WG_M, N_aligned] + uint offset = (id_wg_m * N_aligned + id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * sizeof(half); + cm_ptr_store((int*)kq_max_wg, offset, max_n.format()); + } + { + // kq_exp_partial_sum: [b, hq, N_aligned, M/(BLOCK_SIZE/STRIDE)] + matrix sum = 0; + vector n_pos = (uint)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N + q_start_strided) + seq; + #pragma unroll + for (uint m = 0; m < BLOCK_SG_M / block_size_div_stride; m++) { + #pragma unroll + for (uint sub_m = 0; sub_m < block_size_div_stride; sub_m++) { + uint real_m = m * block_size_div_stride + sub_m; +#if IS_CAUSAL == 1 + // to following case: + // 0 0 1 1 + // 0 0 0 1 + // the acc value of first column should be -inf --> max(first column) == -inf --> exp(first column - max) == 1, this is incorrect + // so need to use simd_if to detect per element state + SIMD_IF_BEGIN ((m_start + real_m <= n_pos) & (real_m < valid_m)) { + sum.row(m) += cm_exp((acc_half.row(real_m) - max_n) * log2e); + } SIMD_IF_END; +#else + if (real_m < valid_m) + sum.row(m) += cm_exp((acc_half.row(real_m) - max_n) * log2e); +#endif + } + } + Transpose_8x32(sum, sum_t); + } + } + // store + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, N - 1, (uint)(M_block_aligned * sizeof(half) - 1), (uint)(M_block_aligned * sizeof(half) - 1), + (int)((id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) / block_size_div_stride), (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) }; + cm_store(desc_c, sum_t.select<8, 1, 8, 1>( 0).format()); + cm_store(desc_c, sum_t.select<8, 1, 8, 1>( 8).format()); + cm_store(desc_c, sum_t.select<8, 1, 8, 1>(16).format()); + cm_store(desc_c, sum_t.select<8, 1, 8, 
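// Reference sketch: the partial-sum block above folds every block_size_div_stride
// consecutive strided rows into one attention-block score, a sum of exp(s - max),
// which the later find() stage normalizes and thresholds into the block mask.
// Host-side reference (container-based, names illustrative):
//
//   #include <cmath>
//   #include <vector>
//   // scores: [rows][cols]; max_c: per-column max; rows_per_block divides rows
//   std::vector<std::vector<float>> block_exp_sums(
//           const std::vector<std::vector<float>>& scores,
//           const std::vector<float>& max_c, int rows_per_block) {
//       int blocks = (int)scores.size() / rows_per_block;
//       std::vector<std::vector<float>> out(blocks,
//               std::vector<float>(scores[0].size(), 0.0f));
//       for (int b = 0; b < blocks; ++b)
//           for (int r = 0; r < rows_per_block; ++r)
//               for (size_t c = 0; c < scores[0].size(); ++c)
//                   out[b][c] += std::exp(scores[b * rows_per_block + r][c] - max_c[c]);
//       return out;
//   }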
1>(24).format()); +} +#endif + +#if 1 || (BLOCK_SG_M == 64 && BLOCK_SG_N == 32) +// const static int channels_reduce_32[] = { 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30, +// 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31}; +// src_a is query, src_b is key +CM_INLINE void gemm_qk_64x32_xe2(uint id_wg_m, uint id_wg_n, uint hq, uint slm, svmptr_t key_cache ATTR, svmptr_t query ATTR, svmptr_t block_indices ATTR, svmptr_t block_indices_begins ATTR, svmptr_t kq_max_wg ATTR, svmptr_t kq_exp_partial_sum ATTR, +uint M, uint N, uint K, uint query_stride, uint q_start_strided) { + constexpr int SG_SIZE = 16; + constexpr int BLOCK_WG_K = 64; // same in sg +#ifndef BLOCK_SG_M + #define BLOCK_SG_M 64 + #define BLOCK_SG_N 32 + #define SG_M 4 + #define SG_N 4 + #define HEAD_SIZE 128 + #define KV_BLOCK_SIZE 256 + #define STRIDE 16 +#endif + // xehpg DPAS spec: dst: [8, 8], repeat: 1~8, depth: 8 + static constexpr int REPEAT = 8; + static constexpr int DEPTH = 8; + static constexpr int BLOCK_REG_M = REPEAT; + static constexpr int BLOCK_REG_N = SG_SIZE; + static constexpr int BLOCK_DPAS_C = BLOCK_REG_M * BLOCK_REG_N; + static constexpr int VNNI = sizeof(half); + static constexpr int BLOCK_REG_K = DEPTH * sizeof(int) / VNNI; + static constexpr int BLOCK_REG_A = BLOCK_REG_M * BLOCK_REG_K; + static constexpr int BLOCK_REG_B = BLOCK_REG_N * BLOCK_REG_K; + static constexpr int BLOCK_WG_M = SG_M * BLOCK_SG_M; + static constexpr int BLOCK_WG_N = SG_N * BLOCK_SG_N; + // register blocking + static constexpr int REG_M = BLOCK_SG_M / BLOCK_REG_M; + static constexpr int REG_N = BLOCK_SG_N / BLOCK_REG_N; + static constexpr int REG_K = BLOCK_WG_K / BLOCK_REG_K; + static constexpr int REG_MN = REG_M * REG_N; + static constexpr int KEY_LINES_PER_LOAD = KV_BLOCK_SIZE / STRIDE; + + matrix acc = 0; // --> 64*2 regs + uint id_sg_n = cm_local_id(0); + uint id_sg_m = cm_local_id(1); + uint id_sg_mn = id_sg_m * SG_N + id_sg_n; + + static_assert(REG_N == 2, "block_2d_desc for b is manually unrolled by 2"); + static_assert(HEAD_SIZE % BLOCK_WG_K == 0, "K dimension must be multiple of BLOCK_WG_K"); + static_assert(KV_BLOCK_SIZE == 256, "block size of key(key_cache) should be 256"); + uint N_block = (N + BLOCK_WG_N - 1) / BLOCK_WG_N; + uint M_aligned = (M + BLOCK_WG_M - 1) / BLOCK_WG_M * BLOCK_WG_M; + uint K_block_pad = N_block * (BLOCK_WG_N / (BLOCK_SIZE / STRIDE)); + const uint block_size_div_stride = BLOCK_SIZE / STRIDE; + constexpr SOFTMAX_TYPE log2e = 1.4426950408889634f; + //static_assert(BLOCK_SG_M / block_size_div_stride == 8, "BLOCK_SG_M / block_size_div_stride should be 8"); + +#if IS_CAUSAL == 1 + if ((int)(id_wg_n * BLOCK_WG_N) >= ((int)id_wg_m + 1) * BLOCK_WG_M + q_start_strided) { + // fill -inf -> max in group, 0 -> exp_sum to make compensation work + { + // current max -> mem + vector max_m = -60000; + // kq_max_wg: [b, hq, N/BLOCK_WG_N, M_aligned] + uint offset = (id_wg_n * M_aligned + id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * sizeof(SOFTMAX_TYPE); + cm_ptr_store((int*)kq_max_wg, offset, max_m.format()); + } + { + // store + matrix sum_t = 0; + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, M - 1, (uint)(K_block_pad * sizeof(SOFTMAX_TYPE) - 1), (uint)(K_block_pad * sizeof(SOFTMAX_TYPE) - 1), + (int)((id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) / block_size_div_stride), (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) }; + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, 
sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + cm_store(desc_c, sum_t.format()); + } + + return; + } +#endif + // assume block index coming from 0 in block_indices_begins + int block_index_begin = ((int*)block_indices_begins)[0]; + int* block_indices_p = (int*)block_indices + block_index_begin; + int b_adjacent_between_head = query_stride / STRIDE; + // M[0:16*2]xK[0:16] + lsc::block_2d_desc desc_a{ query, M - 1, (uint)((query_stride - hq * HEAD_SIZE) * sizeof(half) - 1), (uint)(query_stride * sizeof(half) - 1), + (STRIDE - 1) * b_adjacent_between_head, (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) }; + // prefetch A + static constexpr int SG_MN = SG_M * SG_N; + lsc::block_2d_desc desc_prefetch_a{query, M - 1, (uint)((query_stride - hq * HEAD_SIZE) * sizeof(half) - 1), (uint)(query_stride * sizeof(half) - 1), + (STRIDE - 1) * b_adjacent_between_head, (int)(id_wg_m * BLOCK_WG_M + id_sg_mn * (BLOCK_WG_M / SG_MN)) }; + // M[0:16*2]xK[0:16] --> 8+8 regs + matrix a0; + + // M[0:16]xK[0:32] + uint block_idx = (uint)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) * STRIDE / KV_BLOCK_SIZE; + uint max_block_idx = (uint)(N * STRIDE + KV_BLOCK_SIZE - 1) / KV_BLOCK_SIZE - 1; + block_idx = MYMIN(block_idx, max_block_idx); +#if USE_INT8 + uint offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); + lsc::block_2d_desc desc_b0{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(char) - 1), (uint)(K * sizeof(char) - 1), + 0, 0 }; + uint scale_offset0 = offset + KV_BLOCK_SIZE * HEAD_SIZE; + block_idx = MYMIN(block_idx + 1, max_block_idx); + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); + lsc::block_2d_desc desc_b1{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(char) - 1), (uint)(K * sizeof(char) - 1), + 0, 0 }; + uint scale_offset1 = offset + KV_BLOCK_SIZE * HEAD_SIZE; + // prefetch B + block_idx = (uint)(id_wg_n * BLOCK_WG_N + id_sg_mn * (BLOCK_WG_N / SG_MN)) * STRIDE / KV_BLOCK_SIZE; + block_idx = MYMIN(block_idx, max_block_idx); + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); + static_assert(BLOCK_WG_N / SG_MN <= KEY_LINES_PER_LOAD, "prefetch lines should be inside one block"); + lsc::block_2d_desc desc_prefetch_b{ key_cache + offset, BLOCK_WG_N / SG_MN - 1, (uint)(K * sizeof(char) - 1), (uint)(K * sizeof(char) - 1), + 0, 0 }; + + // N[:]xK[0:32] --> 16 * 1 regs + matrix b0_up_s8, b0_down_s8, b1_up_s8, b1_down_s8; + matrix b0; // --> 16 regs + matrix scales, zps; + matrix scales_block, zps_block; +#else + uint offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_b0{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // printf("===============0 lid:%d.%d, block_idx=%d, offset=%u\n", id_sg_n, id_sg_m, block_idx, offset); + block_idx = MYMIN(block_idx + 1, max_block_idx); + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + lsc::block_2d_desc desc_b1{ key_cache + offset, KEY_LINES_PER_LOAD - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // printf("===============1 lid:%d.%d, block_idx=%d, offset=%u\n", id_sg_n, id_sg_m, block_idx, offset); + // prefetch B + block_idx = (uint)(id_wg_n * BLOCK_WG_N + id_sg_mn * (BLOCK_WG_N / SG_MN)) * STRIDE / KV_BLOCK_SIZE; + block_idx = MYMIN(block_idx, 
max_block_idx); + offset = block_indices_p[block_idx] * (HK * KV_BLOCK_SIZE * HEAD_SIZE * (uint)sizeof(half)); + static_assert(BLOCK_WG_N / SG_MN <= KEY_LINES_PER_LOAD, "prefetch lines should be inside one block"); + lsc::block_2d_desc desc_prefetch_b{ key_cache + offset, BLOCK_WG_N / SG_MN - 1, (uint)(K * sizeof(half) - 1), (uint)(K * sizeof(half) - 1), + 0, 0 }; + // printf("===============2 lid:%d.%d, block_idx=%d, offset=%u\n", id_sg_n, id_sg_m, block_idx, offset); + // 0~2 M[:]xK[0:16] 2~4 K[16:32] --> 32 * 2 regs + matrix b0, b1; +#endif + + // warmup + // prefetch + cm_prefetch(desc_prefetch_b); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + cm_prefetch(desc_prefetch_a); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load b: N[0:16]xK[0:16] +#if USE_INT8 + { + lsc::block_2d_desc desc_scale{ key_cache + scale_offset0, 16 * 2 - 1, (uint)(16 * sizeof(half) - 1), (uint)(16 * sizeof(half) - 1), + 0, 0 }; + matrix tmp_scale, tmp_zp; + cm_load(tmp_scale.format(), desc_scale); + cm_load(tmp_zp.format(), desc_scale); + scales_block[0].format().select<8, 2, 16, 1>(0) = tmp_scale.format().select<8, 1, 16, 2>(0, 0); + scales_block[0].format().select<8, 2, 16, 1>(1) = tmp_scale.format().select<8, 1, 16, 2>(0, 1); + zps_block[0].format().select<8, 2, 16, 1>(0) = tmp_zp.format().select<8, 1, 16, 2>(0, 0); + zps_block[0].format().select<8, 2, 16, 1>(1) = tmp_zp.format().select<8, 1, 16, 2>(0, 1); + desc_scale.set_base(key_cache + scale_offset1); + cm_load(tmp_scale.format(), desc_scale); + cm_load(tmp_zp.format(), desc_scale); + scales_block[1].format().select<8, 2, 16, 1>(0) = tmp_scale.format().select<8, 1, 16, 2>(0, 0); + scales_block[1].format().select<8, 2, 16, 1>(1) = tmp_scale.format().select<8, 1, 16, 2>(0, 1); + zps_block[1].format().select<8, 2, 16, 1>(0) = tmp_zp.format().select<8, 1, 16, 2>(0, 0); + zps_block[1].format().select<8, 2, 16, 1>(1) = tmp_zp.format().select<8, 1, 16, 2>(0, 1); + } + + cm_load(b0_up_s8.format(), desc_b0); + cm_load(b0_down_s8.format(), desc_b1); +#else + cm_load(b0[0].format(), desc_b0); + cm_load(b0[1].format(), desc_b1); +#endif + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); + + cm_sbarrier(1); + +#if USE_INT8 + auto dec = [&](vector B0_i8, vector B1_i8, matrix_ref B0) { +#pragma unroll + for (int n = 0; n < REG_N; n++) { + auto b = B0[n].format(); +#pragma unroll + for (int m = 0; m < 8; m++) { + auto b_row = b[m]; + vector d0; + if (n == 0) + d0 = B0_i8.format()[m / 2].select<16, 2>(m % 2); + else + d0 = B1_i8.format()[m / 2].select<16, 2>(m % 2); + b_row.format() = d0.format(); + b_row *= half{32768.0}; + b_row *= half{512.0}; + b_row = (b_row - zps[n]) * scales[n]; + } + } + }; +#endif + auto dot = [&](matrix A, matrix B) { +#pragma unroll + for (int reg_n = 0; reg_n < REG_N; reg_n++) { +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M; reg_m++) { + acc.row((ushort)(reg_m * REG_N + reg_n)) = cm_dpas(acc.row((ushort)(reg_m * REG_N + reg_n)), + B.row((ushort)reg_n).format(), A.row((ushort)reg_m).format()); + } + } + }; + + for (uint s = 0; s < STRIDE; s++) { +#if USE_INT8 + auto tmp = scales_block[0].select<16, 1>(s * 16); + scales[0].select<16, 2>(0) = tmp; + scales[0].select<16, 2>(1) = scales[0].select<16, 2>(0); + tmp = scales_block[1].select<16, 1>(s * 16); + scales[1].select<16, 2>(0) = tmp; + scales[1].select<16, 2>(1) = scales[1].select<16, 2>(0); + tmp = zps_block[0].select<16, 1>(s * 16); + zps[0].select<16, 2>(0) = tmp; + zps[0].select<16, 
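// Reference sketch: the pair of multiplies in `dec` above (x 32768, then x 512) looks
// like the standard half exponent trick: a half whose bit pattern is 0x00XX is the
// subnormal XX * 2^-24, and 32768 * 512 = 2^24, split in two because 2^24 is not
// representable in half. After that the usual asymmetric dequantization applies:
//
//   // Per-element effect of dec (scalar sketch; zp/scale are per-row half values):
//   float dequant_u8(unsigned char q, float zp, float scale) {
//       return (float(q) - zp) * scale;
//   }
//   // Exponent trick: value(0x00XX) = XX * 2^-24, so (XX * 2^-24) * 2^15 * 2^9 == XX.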
2>(1) = zps[0].select<16, 2>(0); + tmp = zps_block[1].select<16, 1>(s * 16); + zps[1].select<16, 2>(0) = tmp; + zps[1].select<16, 2>(1) = zps[1].select<16, 2>(0); +#endif + #pragma unroll + for (uint hs = 0; hs < HEAD_SIZE / BLOCK_WG_K; hs++) { + // prefetch + cm_prefetch(desc_prefetch_b); + cm_prefetch(desc_prefetch_a); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + if (hs == HEAD_SIZE / BLOCK_WG_K - 1) + desc_prefetch_a.set_block_x((STRIDE - 1 - s - 1) * b_adjacent_between_head); + else + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load a: M[0:16*4]xK[0:16] + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(0).format(), desc_a); + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(4).format(), desc_a); + // load b: N[0:16*2]xK[16:32] +#if USE_INT8 + cm_load(b1_up_s8.format(), desc_b0); + cm_load(b1_down_s8.format(), desc_b1); + dec(b0_up_s8.format().select<64, 1>(), b0_down_s8.format().select<64, 1>(), b0); + dot(a0, b0); +#else + cm_load(b1[0].format(), desc_b0); + cm_load(b1[1].format(), desc_b1); + dot(a0, b0); +#endif + + desc_a.set_block_x(desc_a.get_block_x() + 16); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); + + // load a: M[0:16*4]xK[16:32] + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(0).format(), desc_a); + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(4).format(), desc_a); + +#if USE_INT8 + dec(b0_up_s8.format().select<64, 1>(64), b0_down_s8.format().select<64, 1>(64), b0); + dot(a0, b0); +#else + cm_load(b0[0].format(), desc_b0); + cm_load(b0[1].format(), desc_b1); + dot(a0, b1); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); +#endif + desc_a.set_block_x(desc_a.get_block_x() + 16); + + // prefetch + cm_prefetch(desc_prefetch_b); + cm_prefetch(desc_prefetch_a); + desc_prefetch_b.set_block_x(desc_prefetch_b.get_block_x() + 32); + desc_prefetch_a.set_block_x(desc_prefetch_a.get_block_x() + 32); + + // load a: M[0:16*4]xK[32:48] + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(0).format(), desc_a); + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(4).format(), desc_a); + + // load b: N[0:16*2]xK[32:64] +#if USE_INT8 + cm_load(b0_up_s8.format(), desc_b0); + cm_load(b0_down_s8.format(), desc_b1); + dec(b1_up_s8.format().select<64, 1>(), b1_down_s8.format().select<64, 1>(), b0); + dot(a0, b0); +#else + cm_load(b1[0].format(), desc_b0); + cm_load(b1[1].format(), desc_b1); + dot(a0, b0); +#endif + + desc_a.set_block_x(desc_a.get_block_x() + 16); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); + + // load a: M[0:16*4]xK[48:64] + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(0).format(), desc_a); + cm_load(a0.select<4, 1, BLOCK_REG_A, 1>(4).format(), desc_a); + if (hs == HEAD_SIZE / BLOCK_WG_K - 1) { + desc_a.set_block_x((STRIDE - 1 - s - 1) * b_adjacent_between_head); + } else { + desc_a.set_block_x(desc_a.get_block_x() + 16); + } + +#if USE_INT8 + dec(b1_up_s8.format().select<64, 1>(64), b1_down_s8.format().select<64, 1>(64), b0); + dot(a0, b0); +#else + cm_load(b0[0].format(), desc_b0); + cm_load(b0[1].format(), desc_b1); + desc_b0.set_block_x(desc_b0.get_block_x() + 8); + desc_b1.set_block_x(desc_b1.get_block_x() + 8); + dot(a0, b1); +#endif + + cm_sbarrier(0); + cm_sbarrier(1); + } + } + + cm_sbarrier(0); + + matrix acc_half; + union { + float f; + int y; + } i2f; + i2f.y = INV_S; + const float inv_s = i2f.f; +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M; reg_m++) { +#pragma unroll + for (int reg_n = 0; reg_n < REG_N; 
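// Reference sketch: the b0/b1 ping-pong in the loop above is software pipelining:
// issue the loads for tile t+1, then run the DPAS work for tile t, so memory latency
// hides behind compute. Its skeleton, stripped of CM specifics (hypothetical
// load/compute callbacks, two alternating buffers):
//
//   #include <functional>
//   void pipelined_loop(int tiles,
//                       const std::function<void(int, int)>& load,   // (tile, buf_id)
//                       const std::function<void(int)>& compute) {   // (buf_id)
//       load(0, 0);                                        // warmup load
//       for (int t = 0; t < tiles; ++t) {
//           if (t + 1 < tiles) load(t + 1, (t + 1) % 2);   // prefetch next tile
//           compute(t % 2);                                // consume current tile
//       }
//   }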
reg_n++) { + acc_half.select(reg_m * BLOCK_REG_M, reg_n * BLOCK_REG_N) = + acc.row(reg_m * REG_N + reg_n) * inv_s; + } + } + + // if M(aka query) has tails, the following will not change the accuracy: + // gemm will compute results for the padding M(all should be zeros), the kq_max/kq_max_wg/kq_exp_partial_sum are along the query dimension and + // the results can be dropped in the future stage. To simplify the logic, the size of kq_max/kq_max_wg/kq_exp_partial_sum must be enough to hold + // all tails + padding results. + int n_start = (int)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N); + n_start = MYMIN(n_start, N); + int n_end = MYMIN(n_start + BLOCK_SG_N, N); + int valid_n = n_end - n_start; + matrix sum_t; + vector seq_m; + cmtl::cm_vector_assign(seq_m.select_all(), 0, 1); + vector_ref seq = seq_m.select(); + vector n_pos = (uint)(id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) + seq; +#if IS_CAUSAL == 1 + bool skip_mask = false; + int m_start = (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M + q_start_strided); + // in streaming scenario, the past kvcache length may be arbitrary so valid causal mask of a workgroup may start at arbitrary position + if (n_end <= m_start) { + // all are inside causal mask == 1 + skip_mask = true; + } else { + #pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + SIMD_IF_BEGIN (n_pos > m_start + reg_m) { + acc_half.row(reg_m) = SOFTMAX_TYPE{-60000}; + } SIMD_IF_END; + } + } +#else + bool skip_mask = true; +#endif + vector max_m; + if (valid_n != BLOCK_SG_N) { +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + acc_half.row(reg_m).merge(SOFTMAX_TYPE{-60000}, n_pos >= N); + } + } + max_m.select<32, 1>() = reduce2d<1, 0, 1>(acc_half.select<32, 1, 32, 1>()).format(); + max_m.select<32, 1>(32) = reduce2d<1, 0, 1>(acc_half.select<32, 1, 32, 1>(32)).format(); + + { + uint slm_offset = (id_sg_n * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * (uint)sizeof(SOFTMAX_TYPE); + // current max -> slm + cm_slm_block_write(slm, slm_offset, max_m.format()); + cm_slm_fence(CM_LOCAL_BARRIER); + cm_barrier(); + // max inside wg + cm_slm_block_read(slm, id_sg_m * BLOCK_SG_M * (uint)sizeof(SOFTMAX_TYPE), max_m.format()); + vector tmp; +#pragma unroll + for (uint i = 1; i < SG_N; i++) { + slm_offset = (i * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * (uint)sizeof(SOFTMAX_TYPE); + cm_slm_block_read(slm, slm_offset, tmp.format()); + max_m = cm_max(max_m, tmp); + } + // max across wg + // kq_max: [b, hq, M_aligned] + //auto max_offsets = (id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M + seq_m) * (uint)sizeof(half); + //cm_ptr_atomic((half*)kq_max, max_offsets, max_m); + + // current max -> mem + // kq_max_wg: [b, hq, N/BLOCK_WG_N, M_aligned] + uint offset = (id_wg_n * M_aligned + id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) * sizeof(SOFTMAX_TYPE); + cm_ptr_store((int*)kq_max_wg, offset, max_m.format()); + } + { + // kq_exp_partial_sum: [b, hq, M_aligned, N/(BLOCK_SIZE/STRIDE)] + if (valid_n == BLOCK_SG_N && skip_mask) { +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + acc_half.row(reg_m) = cm_exp((acc_half.row(reg_m) - max_m[reg_m]) * log2e); + } + } else { +#pragma unroll + for (uint reg_m = 0; reg_m < REG_M * BLOCK_REG_M; reg_m++) { + acc_half.row(reg_m) = cm_exp((acc_half.row(reg_m) - max_m[reg_m]) * log2e); + // causal mask in the following case: + // block0(EU0) block1(EU1) + // 1 1 1 1 0 0 0 0 + // 1 1 1 1 1 0 0 0 + // 1 1 1 1 1 1 0 0 + // 1 1 1 1 1 1 1 0 + // the acc value of first row of block1 should be -inf 
--> max(first column) == -inf --> exp(first column - max) == 1, this is incorrect + // so need to use simd_if to detect per element state +#if IS_CAUSAL + SIMD_IF_BEGIN ((n_pos > m_start + reg_m) | (n_pos >= N)) { +#else + SIMD_IF_BEGIN (n_pos >= N) { +#endif + acc_half.row(reg_m) = 0; + } SIMD_IF_END; + } + } + sum_t.select<32, 1, 4, 1>( 0).format() = reduce2d<4, 1, 4>(acc_half.select<32, 1, 32, 1>( 0)).format(); + sum_t.select<32, 1, 4, 1>(32).format() = reduce2d<4, 1, 4>(acc_half.select<32, 1, 32, 1>(32)).format(); + } + // store + lsc::block_2d_desc desc_c{ kq_exp_partial_sum, M - 1, (uint)(K_block_pad * sizeof(SOFTMAX_TYPE) - 1), (uint)(K_block_pad * sizeof(SOFTMAX_TYPE) - 1), + (int)((id_wg_n * BLOCK_WG_N + id_sg_n * BLOCK_SG_N) / block_size_div_stride), (int)(id_wg_m * BLOCK_WG_M + id_sg_m * BLOCK_SG_M) }; + cm_store(desc_c, sum_t.select<8, 1, 4, 1>( 0).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>( 8).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(16).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(24).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(32).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(40).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(48).format()); + cm_store(desc_c, sum_t.select<8, 1, 4, 1>(56).format()); +} +#endif diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp new file mode 100644 index 00000000000000..a4c57a19d1e8e7 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/find_block.hpp @@ -0,0 +1,247 @@ +/* + * Copyright (c) 2020-2023, Intel Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +#include +#include + +#include "sort.hpp" +#define MYMIN(x, y) ((x) < (y) ? 
(x) : (y)) + +#define MYCONCAT(x, y) x ## y +#define IS_float 1 +#define IS_half 2 +#define CUR_TYPE_(a) MYCONCAT(IS_, a) +#define CUR_TYPE CUR_TYPE_(SOFTMAX_TYPE) + +// kq_max_wg: [b, hq, n_groups, q_stride_pad] +// kq_exp_partial_sum: [b, hq, q_stride_pad, k_block_pad] +// kq_sum: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] +CM_INLINE void find(uint slm, int m_block, svmptr_t kq_max_wg, svmptr_t kq_exp_partial_sum, svmptr_t block_mask, uint q_len, uint q_stride, uint q_stride_pad, uint k_block_pad, float thresh, uint causal_start_index +#if DEBUG_ACC == 1 + , svmptr_t kq_sum +#endif +) { + constexpr int SG_SIZE = 16; +#ifndef BLOCK_SG_M + #define BLOCK_SG_M 64 + #define BLOCK_SG_N 32 + #define SG_M 2 + #define SG_N 4 + #define HEAD_SIZE 128 + #define KV_BLOCK_SIZE 256 + #define STRIDE 16 + #define BLOCK_SIZE 128 + #define BLOCK_SHARE_MAX 256 +#endif + + const int TOKEN_IN_BLOCK = (BLOCK_SIZE / STRIDE); + int m = m_block * TOKEN_IN_BLOCK; + vector max_m; + + const int TOKEN_SHARE_MAX = BLOCK_SHARE_MAX / TOKEN_IN_BLOCK; + kq_exp_partial_sum += m * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + kq_max_wg += m * (int)sizeof(SOFTMAX_TYPE); + constexpr SOFTMAX_TYPE log2e = 1.4426950408889634f; + matrix sum_m = 0; + matrix data; + int m_start = MYMIN(m, q_stride); + int m_end = MYMIN(m_start + TOKEN_SHARE_MAX, q_stride); + int valid_m = m_end - m_start; + block_mask += m_block * k_block_pad; + if (valid_m == 0) { + // case for tails: q is not inside mask, aka q % BLOCK_SIZE < STRIDE + if (m * STRIDE < q_len) { + vector one = 1; + for (int j = 0; j < k_block_pad; j += TOKEN_SHARE_MAX) { + cm_ptr_store((int*)block_mask, j, one.format()); + } + } + return; + } + lsc::block_2d_desc desc_sum{ kq_exp_partial_sum, (uint)valid_m - 1, (uint)(k_block_pad * sizeof(SOFTMAX_TYPE) - 1), (uint)(k_block_pad * sizeof(SOFTMAX_TYPE) - 1), + 0, 0 }; + { + // find max: (k_block_pad / TOKEN_SHARE_MAX) * q_stride_pad + max_m = SOFTMAX_TYPE{-60000}; + + for (int idx = 0; idx < k_block_pad / TOKEN_SHARE_MAX; idx++) { + vector max_m_in_group; + max_m_in_group.format() = cm_ptr_load((int*)kq_max_wg, q_stride_pad * idx * (int)sizeof(SOFTMAX_TYPE)); + max_m = cm_max(max_m, max_m_in_group); + } + } + // compensation: val*exp(local - global) + desc_sum.set_block_x(0); + for (int j = 0, idx = 0; j < k_block_pad; j += TOKEN_SHARE_MAX, idx++) { + vector max_m_in_group; + max_m_in_group.format() = cm_ptr_load((int*)kq_max_wg, q_stride_pad * idx * (int)sizeof(SOFTMAX_TYPE)); +#if CUR_TYPE == IS_float + cm_load(data.select(0, 0).format(), desc_sum); + cm_load(data.select(0, TOKEN_SHARE_MAX / 2).format(), desc_sum); +#else + cm_load(data.format(), desc_sum); +#endif + for (int i = 0; i < TOKEN_IN_BLOCK; i++) { + if (i < valid_m) { + data.row(i) *= cm_exp((max_m_in_group[i] - max_m[i]) * log2e); + sum_m.row(i) += data.row(i); + } + } +#if CUR_TYPE == IS_float + cm_store(desc_sum, data.select(0, 0).format()); + cm_store(desc_sum, data.select(0, TOKEN_SHARE_MAX / 2).format()); +#else + cm_store(desc_sum, data.format()); +#endif + desc_sum.set_block_x(desc_sum.get_block_x() + TOKEN_SHARE_MAX); + } + + // exp/sum + vector inv_sum_v; + for (int i = 0; i < TOKEN_IN_BLOCK; i++) { + if (i < valid_m) + inv_sum_v[i] = 1.0f / cm_sum(sum_m.row(i)); + else + inv_sum_v[i] = 0; + } + // compensation: sum(val*inv_sum_v) + vector sum_m_after_add = 0; + desc_sum.set_block_x(0); +#if DEBUG_ACC == 1 + kq_sum += m_block * k_block_pad * (int)sizeof(half); +#endif + vector zero = 0; + for (int j = 0; j < k_block_pad; j += TOKEN_SHARE_MAX) { +#if 
CUR_TYPE == IS_float + cm_load(data.select(0, 0).format(), desc_sum); + cm_load(data.select(0, TOKEN_SHARE_MAX / 2).format(), desc_sum); +#else + cm_load(data.format(), desc_sum); +#endif + data.row(0) *= inv_sum_v[0]; + for (int i = 1; i < TOKEN_IN_BLOCK; i++) { + data.row(0) += data.row(i) * inv_sum_v[i]; + } + desc_sum.set_block_x(desc_sum.get_block_x() + TOKEN_SHARE_MAX); + // the sum type is always half in the reference code of the paper: https://github.com/mit-han-lab/x-attention/blob/fb2ac200a23d20568f7d166ddb5ee247926d2b2b/xattn/src/kernels.py#L248 + vector data_half = data.row(0); + sum_m_after_add += data_half; + cm_ptr_store((int*)kq_exp_partial_sum, j * (int)sizeof(half), data_half.format()); +#if DEBUG_ACC == 1 + cm_ptr_store((int*)kq_sum, j * (int)sizeof(half), data_half.format()); +#endif + cm_ptr_store((int*)block_mask, j, zero.format()); + } + auto thresh_act = cm_sum(sum_m_after_add) * thresh; + + // content of the 8 (aka STRIDE) scratch lines: + // line 0: score + // line 1: sorted value + // line 3: sorted index + // line 5: sorted tmp + // line 6: cumulative score + auto score = kq_exp_partial_sum + 0 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + auto sorted_value = kq_exp_partial_sum + 1 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + auto sorted_index = kq_exp_partial_sum + 3 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + auto sorted_tmp = kq_exp_partial_sum + 5 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + auto acc_score = kq_exp_partial_sum + 6 * k_block_pad * (int)sizeof(SOFTMAX_TYPE); + +#if IS_CAUSAL == 1 + auto score_p = (half*)score; + half s_0 = score_p[0]; + half s_causal = score_p[causal_start_index + m_block]; + float s_sum = s_0; + if (causal_start_index + m_block) s_sum += s_causal; + score_p[0] = -1; + score_p[causal_start_index + m_block] = -1; + sort(slm, score, sorted_value + 2 * sizeof(half), sorted_index + 2 * sizeof(half), sorted_tmp, k_block_pad); + uchar* block_mask_p = (uchar*)block_mask; + auto sorted_value_p = (half*)sorted_value; + auto sorted_index_p = (ushort*)sorted_index; + auto acc_score_p = (half*)acc_score; + sorted_value_p[0] = 0; + sorted_value_p[1] = s_sum; + block_mask_p[0] = 1; + block_mask_p[causal_start_index + m_block] = 1; + float sum_cur = s_sum; +#if DEBUG_ACC == 1 + acc_score_p[0] = 0; + acc_score_p[1] = 0; +#endif + int j; + for (j = 2; j < k_block_pad; j++) { +#if DEBUG_ACC == 1 + acc_score_p[j] = sum_cur; +#endif + if (sum_cur < thresh_act) { + auto k_idx = sorted_index_p[j]; + if (k_idx <= causal_start_index + m_block) + block_mask_p[k_idx] = 1; + } else { + break; + } + sum_cur += sorted_value_p[j]; + } +#if DEBUG_ACC == 1 + for (; j < k_block_pad; j++) { + acc_score_p[j] = sum_cur; + sum_cur += sorted_value_p[j]; + } +#endif + + // for (int j = causal_start_index + m_block + 1; j < k_block_pad; j++) + // block_mask_p[j] = 0; + +#else + + sort(slm, score, sorted_value, sorted_index, sorted_tmp, k_block_pad); + uchar* block_mask_p = (uchar*)block_mask; + auto sorted_value_p = (half*)sorted_value; + auto sorted_index_p = (ushort*)sorted_index; + auto acc_score_p = (half*)acc_score; + block_mask_p[0] = 1; + float sum_cur = 0; +#if DEBUG_ACC == 1 + acc_score_p[0] = 0; +#endif + int j; + for (j = 0; j < k_block_pad - 1; j++) { + sum_cur += sorted_value_p[j]; +#if DEBUG_ACC == 1 + acc_score_p[j + 1] = sum_cur; +#endif + if (sum_cur < thresh_act) { + block_mask_p[sorted_index_p[j]] = 1; + } else { + block_mask_p[sorted_index_p[j]] = 1; + break; + } + } +#if DEBUG_ACC == 1 + for (j = j + 1; j < k_block_pad - 1; j++) { + sum_cur += 
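// Reference sketch: the selection loop above is the heart of the xattn block mask:
// rank blocks by softmax mass and keep the best ones until the kept mass reaches
// thresh * total; block 0 (and the diagonal block in the causal path) is kept
// unconditionally. A host-side equivalent of the non-causal branch:
//
//   #include <algorithm>
//   #include <numeric>
//   #include <vector>
//   std::vector<char> select_blocks(const std::vector<float>& score, float thresh) {
//       std::vector<int> order(score.size());
//       std::iota(order.begin(), order.end(), 0);
//       std::sort(order.begin(), order.end(),
//                 [&](int a, int b) { return score[a] > score[b]; });   // descending
//       float total = std::accumulate(score.begin(), score.end(), 0.0f);
//       std::vector<char> mask(score.size(), 0);
//       float kept = 0.0f;
//       for (int idx : order) {
//           mask[idx] = 1;                         // keep this block
//           kept += score[idx];
//           if (kept >= thresh * total) break;     // enough attention mass retained
//       }
//       return mask;
//   }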
sorted_value_p[j]; + acc_score_p[j + 1] = sum_cur; + } +#endif + +#endif +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp new file mode 100644 index 00000000000000..97f680e5cba441 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/sort.hpp @@ -0,0 +1,233 @@ +/* + * Copyright (c) 2020-2023, Intel Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +#include +#include + +#ifndef ATTR +#define ATTR [[type("svmptr_t")]] +#define ATTR_BUF [[type("buffer_t")]] +#endif + +#if 0 +template +void show(const vector mat, bool is_hex=true) { + printf("vector [%d]:\n[", N); + for(int n = 0; n < N; n ++) { + if (is_hex) + printf("%x,", mat[n]); + else + printf("%d,", mat[n]); + } + printf("]\n"); +} + +template +void show(const matrix mat) { + printf("Matrix [%d, %d]:\n", M, N); + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[m][n]); + } + printf("],\n"); + } + printf("]\n"); +} +#endif + +// https://gpuopen.com/download/Introduction_to_GPU_Radix_Sort.pdf +template +CM_INLINE void sort(uint slm, svmptr_t src, svmptr_t sorted_value, svmptr_t sorted_index, svmptr_t sort_tmp, uint n) { + const ushort THREADS = 32; + vector seq_u32; + vector seq; + cmtl::cm_vector_assign(seq.select_all(), 0, 1); + cmtl::cm_vector_assign(seq_u32.select_all(), 0, 1); + svmptr_t sorted_src = src; + svmptr_t sorted_tmp = sorted_value; + svmptr_t cur_idx = sorted_index; + svmptr_t cur_idx_tmp = sort_tmp; + uint iter = (n + THREADS - 1) / THREADS; + vector offset_src; + cmtl::cm_vector_assign(offset_src.select_all(), 0, iter); + auto f16_u16 = [] (vector_ref in, vector_ref out) { + static const ushort HIGH_BIT = 1 << 15; + auto in_u16 = in.format(); + auto mask = (in_u16 & HIGH_BIT) != 0; + vector m; + m.merge(ushort{0xffff}, HIGH_BIT, mask); + out = in_u16 ^ m; + }; + auto u16_f16 = [] (vector_ref in, vector_ref out) { + static const ushort HIGH_BIT = 1 << 15; + auto mask = (in & HIGH_BIT) != 0; + vector m; + m.merge(HIGH_BIT, ushort{0xffff}, mask); + out.format() = in ^ m; + }; + if constexpr(std::is_same::value) { + { + // f16 to u16 + vector data; + vector data_f; + int i; + for (i = 0; i + THREADS <= n / THREADS * THREADS; i += THREADS) { + data_f.format() = cm_ptr_load((int*)src, i * (int)sizeof(short)); + f16_u16(data_f, data); + cm_ptr_store((int*)src, i * (int)sizeof(short), data.format()); + } + if (i < n) { + auto pos = 
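// Reference sketch: f16_u16/u16_f16 above are the standard order-preserving remap for
// radix-sorting IEEE floats as unsigned keys: negatives get all bits flipped (which
// reverses their order), non-negatives only the sign bit, so unsigned compare matches
// float compare; the sort loop below then takes 0xf - digit to sort descending.
// Scalar C++ of the remap and its inverse (h/k are raw half bit patterns):
//
//   #include <cstdint>
//   std::uint16_t f16_to_key(std::uint16_t h) {
//       return (h & 0x8000) ? std::uint16_t(~h) : std::uint16_t(h ^ 0x8000);
//   }
//   std::uint16_t key_to_f16(std::uint16_t k) {
//       return (k & 0x8000) ? std::uint16_t(k ^ 0x8000) : std::uint16_t(~k);
//   }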
seq_u32 + i; + SIMD_IF_BEGIN (pos < n) { + data_f = cm_ptr_load((half*)src, pos * (uint)sizeof(half)); + f16_u16(data_f, data); + cm_ptr_store((ushort*)src, pos * (uint)sizeof(ushort), data); + } SIMD_IF_END; + } + } + } + { + // generate idx + vector data; + int i; + for (i = 0; i + THREADS <= n / THREADS * THREADS; i += THREADS) { + data = seq + i; + cm_ptr_store((int*)cur_idx, i * (int)sizeof(short), data.format()); + } + if (i < n) { + auto pos = seq_u32 + i; + data = seq + i; + + SIMD_IF_BEGIN (pos < n) { + cm_ptr_store((ushort*)cur_idx, pos * (uint)sizeof(ushort), data); + } SIMD_IF_END; + } + } + // 4bit per pass, 4 pass for f16 + for (int pass = 0; pass < 4; pass++) { + { + // slm layout: short [16][work items] = 16*32*2 bytes + vector data = 0; + for (int i = 0; i < 16 * THREADS * sizeof(ushort); i += 256) + cm_slm_block_write(slm, i, data); + } + { + // counting phase + vector data; + for (int i = 0; i < iter; i++) { + data = cm_ptr_load((ushort*)sorted_src, (offset_src + i) * (uint)sizeof(ushort), (offset_src + i) < n); + vector bits = 0xf - ((data >> (pass * 4)) & 0xf); + vector addr = bits * THREADS + seq; + vector total; + cm_slm_read(slm, addr, total); + total += 1; + cm_slm_write(slm, addr, total); + } + } + { + // prefix sum + vector local_prefix = 0; + + vector seq_prefix; + cmtl::cm_vector_assign(seq_prefix.select_all(), 0, THREADS); + + #pragma unroll + for (ushort i = 0; i < THREADS; i++) { + auto prev = local_prefix; + vector hist; + vector addr = seq_prefix + i; + cm_slm_read(slm, addr, hist); + local_prefix += hist; + cm_slm_write(slm, addr, prev); + } + // Hillis-Steele scan + vector local_tmp; + local_tmp.select<15, 1>(1) = local_prefix.select<15, 1>(1) + local_prefix.select<15, 1>(0); local_tmp[0] = local_prefix[0]; + local_prefix.select<14, 1>(2) = local_tmp.select<14, 1>(2) + local_tmp.select<14, 1>(0); local_prefix.select<2, 1>(0) = local_tmp.select<2, 1>(0); + local_tmp.select<12, 1>(4) = local_prefix.select<12, 1>(4) + local_prefix.select<12, 1>(0); local_tmp.select<4, 1>(0) = local_prefix.select<4, 1>(0); + local_prefix.select<8, 1>(8) = local_tmp.select<8, 1>(8) + local_tmp.select<8, 1>(0); local_prefix.select<8, 1>(0) = local_tmp.select<8, 1>(0); + vector data; + cm_slm_block_read(slm, 0, data); + #pragma unroll + for (int i = 1; i < 16; i++) { + data.select(i * THREADS) += local_prefix[i - 1]; + } + cm_slm_block_write(slm, 0, data); + } + { + // reorder + vector data; + for (int i = 0; i < iter; i++) { + data = cm_ptr_load((ushort*)sorted_src, (offset_src + i) * (uint)sizeof(ushort), (offset_src + i) < n); + vector bits = 0xf - ((data >> (pass * 4)) & 0xf); + vector addr = bits * THREADS + seq; + vector index; + cm_slm_read(slm, addr, index); + vector offset_i32 = index * (uint)sizeof(ushort); + cm_ptr_store((ushort*)sorted_tmp, offset_i32, data, index < n); + + data = cm_ptr_load((ushort*)cur_idx, (offset_src + i) * (uint)sizeof(ushort)); + cm_ptr_store((ushort*)cur_idx_tmp, offset_i32, data, (offset_src + i) < n); + + index += 1; + cm_slm_write(slm, addr, index); + } + auto tmp = sorted_src; + sorted_src = sorted_tmp; + sorted_tmp = tmp; + tmp = cur_idx; + cur_idx = cur_idx_tmp; + cur_idx_tmp = tmp; + } + } + { + // copy to output + vector data; + vector data_f; + int i; + for (i = 0; i + THREADS <= n / THREADS * THREADS; i += THREADS) { + data.format() = cm_ptr_load((int*)src, i * (int)sizeof(short)); + if constexpr(std::is_same::value) { + u16_f16(data, data_f); + cm_ptr_store((int*)sorted_value, i * (int)sizeof(short), data_f.format()); + } else { + 
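+ // half keys were sign-flip encoded into order-preserving u16 by f16_u16 before the radix passes, so the half path above decodes them back with u16_f16; raw u16 keys round-trip unchanged and are stored directly below.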
cm_ptr_store((int*)sorted_value, i * (int)sizeof(short), data.format()); + } + } + if (i < n) { + auto pos = seq_u32 + i; + if constexpr(std::is_same::value) { + SIMD_IF_BEGIN (pos < n) { + data = cm_ptr_load((ushort*)src, pos * (uint)sizeof(half)); + u16_f16(data, data_f); + cm_ptr_store((half*)sorted_value, pos * (uint)sizeof(half), data_f); + } SIMD_IF_END; + } else { + SIMD_IF_BEGIN (pos < n) { + data = cm_ptr_load((ushort*)src, pos * (uint)sizeof(ushort)); + cm_ptr_store((ushort*)sorted_value, pos * (uint)sizeof(ushort), data); + } SIMD_IF_END; + } + } + } +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm new file mode 100644 index 00000000000000..b019d217bc99c4 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm @@ -0,0 +1,135 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include +#include + +#ifndef ATTR +#define ATTR [[type("svmptr_t")]] +#define ATTR_BUF [[type("buffer_t")]] +#endif + +constexpr uint wg_size = WG_SIZE; +#define REG_K 16 + +extern "C" _GENX_MAIN_ void KERNEL_NAME( + const half* key [[type("svmptr_t")]], + const half* value [[type("svmptr_t")]], + const int32_t* past_lens [[type("svmptr_t")]], + const int32_t* block_indices [[type("svmptr_t")]], + const int32_t* block_indices_begins [[type("svmptr_t")]], + const int32_t* subsequence_begins [[type("svmptr_t")]], +#if KV_CACHE_COMPRESSION_PER_TOKEN + uint8_t* key_cache [[type("svmptr_t")]], + uint8_t* value_cache [[type("svmptr_t")]], +#else + half* key_cache [[type("svmptr_t")]], + half* value_cache [[type("svmptr_t")]], +#endif + uint32_t key_pitch, + uint32_t key_offset, + uint32_t value_pitch, + uint32_t value_offset, + uint32_t batch_size_in_sequences) { + // # key: [batch_size_in_tokens, num_kv_heads * k_head_size] + // # value [batch_size_in_tokens, num_kv_heads * v_head_size] + // # key_cache: [num_blocks, num_heads, block_size, k_head_size] + // # value_cache: [num_blocks, num_heads, block_size, v_head_size] + // + // # past_lens: [sequences_num] + // # subsequence_begins: [sequences_num + 1] + // # block_indices: [used_blocks_num] + // # block_indices_begins: [sequences_num + 1] + + // wg_count = aligned_to(batch_size_in_tokens, wg_size) // wg_size + // # GWS [1, num_heads, wg_count * wg_size] + // # LWS [1, 1, wg_size] + + const auto head_idx = cm_group_id(1); + const auto wg_id = cm_group_id(2); + const auto wg_local_id = cm_local_id(2); + const auto local_size = cm_local_size(2); + + const uint token_idx = cm_global_id(2); + + // token_idx -> subsequence_idx + if (token_idx >= subsequence_begins[batch_size_in_sequences]) return; + uint subsequence_idx = 0; + for (uint i = 0; i < batch_size_in_sequences; i++) { + if (token_idx >= 
subsequence_begins[i] && token_idx < subsequence_begins[i + 1]) { + subsequence_idx = i; + break; + } + } + + const uint subsequence_begin_idx = subsequence_begins[subsequence_idx]; + const uint past_len = past_lens[subsequence_idx]; + const uint current_block_idx = (past_len + token_idx - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE; + const uint token_start_pos = (past_len + token_idx - subsequence_begin_idx) % PAGED_ATTENTION_BLOCK_SIZE; + const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx; + + #if KV_CACHE_COMPRESSION_PER_TOKEN + // Assume: K_HEAD_SIZE == V_HEAD_SIZE + auto quantize_and_store = [&](vector data, uchar* out, uint out_offset, uint token_pos) { + uint scale_offset = out_offset + K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + token_pos * sizeof(half); + half max_val = cm_reduced_max(data); + half min_val = cm_reduced_min(data); + half scale_val = half(0.0); + half zp_val = half(0.0); + if(max_val == min_val) { + scale_val = half(0.0); + zp_val = max_val; + } else { + scale_val = 255.0 / (max_val - min_val); + zp_val = (0.0 - min_val) * scale_val; + } + vector quant_data = cm_mul(data, scale_val) + zp_val; + vector data_u8 = cm_rnde(quant_data); // quantize: round(data * scale + zp) + cm_ptr_store((uint32_t*)(out + out_offset + token_pos * K_HEAD_SIZE), 0, data_u8.format()); + half *out_scale_zp = (half*)(out + scale_offset); + out_scale_zp[0] = (max_val - min_val) / 255.0; + out_scale_zp[PAGED_ATTENTION_BLOCK_SIZE] = zp_val; + }; + #endif + + { + uint block_k_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; + uint key_out_offset = block_k_base_offset + token_start_pos * K_HEAD_SIZE; + uint key_in_offset = token_idx * key_pitch + head_idx * K_HEAD_SIZE + key_offset; + + vector key_data; + key_data.format() = cm_ptr_load((int*)key, key_in_offset * (int)sizeof(half)); + + #if KV_CACHE_COMPRESSION_PER_TOKEN + quantize_and_store(key_data, (uchar*)key_cache, block_k_base_offset, token_start_pos); + #else + cm_ptr_store((int*)key_cache, key_out_offset * (int)sizeof(half), key_data.format()); + #endif + } + { + uint block_v_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_V_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE; + uint value_out_offset = block_v_base_offset + token_start_pos * V_HEAD_SIZE; + uint value_in_offset = token_idx * value_pitch + head_idx * V_HEAD_SIZE + value_offset; + + vector value_data; + value_data.format() = cm_ptr_load((int*)value, value_in_offset * (int)sizeof(half)); + #if KV_CACHE_COMPRESSION_PER_TOKEN + quantize_and_store(value_data, (uchar*)value_cache, block_v_base_offset, token_start_pos); + #else + cm_ptr_store((int*)value_cache, value_out_offset * (int)sizeof(half), value_data.format()); + #endif + } +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm new file mode 100644 index 00000000000000..4410060e47ca56 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm @@ -0,0 +1,181 @@ + +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +namespace KERNEL_NAME { +#include "cm_pa_common.hpp" + +#ifdef CM_HAS_LSC_UNTYPED_2D +#define USE_LSC 1 +#else +#define USE_LSC 0 +#endif + +extern "C" _GENX_MAIN_ void KERNEL_NAME( + //query [q_len, num_heads, S] + half* query [[type("svmptr_t")]], +#if CMPA_KVCACHE_U8 + int8_t* k_cache [[type("svmptr_t")]], + int8_t* v_cache [[type("svmptr_t")]], +#else + half* k_cache [[type("svmptr_t")]], + half* v_cache [[type("svmptr_t")]], +#endif + int32_t* past_lens [[type("svmptr_t")]], + int32_t* block_indices [[type("svmptr_t")]], + int32_t* block_indices_begins [[type("svmptr_t")]], + int32_t* subsequence_begins [[type("svmptr_t")]], + half* output [[type("svmptr_t")]], +#if SPARSE_BLOCK_SIZE > 1 + bool* sparse_block_mask [[type("svmptr_t")]], + bool* sparse_block_mask_wg [[type("svmptr_t")]], + int q_len, + int num_q_blocks, + int num_k_blocks, + // validate the sparse attention process + bool validate) { +#else + int q_len) { +#endif + constexpr int is_causal = CMFLA_IS_CAUSAL; + constexpr int num_heads = CMFLA_NUM_HEADS; + constexpr int head_size = CMFLA_HEAD_SIZE; + constexpr int num_kv_heads = CMFLA_NUM_KV_HEADS; + constexpr int pa_block_sz = CMPA_BLOCK_SZ; + //# query [q_len, num_heads, S] + //# k_cache [kv_len, num_heads, S] + //# v_cache [kv_len, num_heads, S] +#if CMPA_KVCACHE_U8 + constexpr uint K_SLM_SIZE = (4*kv_step * head_size * sizeof(half)); + constexpr uint V_SLM_SIZE = (4*kv_step * head_size * sizeof(half)); + constexpr uint Q_SLM_SIZE = 0; + + cm_slm_init(K_SLM_SIZE + V_SLM_SIZE + Q_SLM_SIZE); + + auto slm_K = cm_slm_alloc(K_SLM_SIZE); + auto slm_V = cm_slm_alloc(V_SLM_SIZE); + +#endif + auto batch = cm_group_id(0); + auto h = cm_group_id(1); + auto hkv = h / (num_heads/num_kv_heads); + auto wg_id = cm_group_id(2); // each work-group handles one chunk of the sequence + auto wg_local_id = cm_local_id(2); + int local_size = cm_local_size(2); + + int q_start_sg, kv_start, kv_seq_len, q_len_sg; + + // multiple work-groups are required to split a sequence, + // so we need to figure out which part of the query tokens to process + int wg_seq_len = local_size * q_step; + int past_q_lens = past_lens[0]; + kv_start = 0; + kv_seq_len = q_len + past_q_lens; + q_start_sg = (wg_id * local_size + wg_local_id) * q_step; + q_len_sg = q_step; + if (q_start_sg + q_len_sg > q_len) { + q_len_sg = q_len - q_start_sg; + } + + // qkv is fused + int kv_stop = kv_seq_len; + if constexpr (is_causal) { + /* + -------------------------------- + | | | | | + | 00 | | | | + | | | | | + -------------------------------- + | | | | | + | 10 | 11 | | | + | | | | | + --------------------------------- + | | | | | + | 20 | 21 | 22 | | + | | | | | + --------------------------------- + | | | | | + | 30 | 31 | 32 | 33 | + | | | | | + --------------------------------- + each grid cell is [q_len_per_chunk, q_len_per_chunk]. + For each chunk, [q_len_per_chunk, past_q_lens] must be calculated, e.g. `20`, `21`; but for `22`, the + causal-mask optimization can be applied. Different WGs will have different kv_stop values. 
+ // TODO: kv_stop is WG-level; should we change it to SG-level? + // SG-level would cause SGs within one WG to diverge, so leave it for now. A uniform kv_stop per WG also makes copying/loading KV into SLM/cache easier. + */ + kv_stop = (wg_id + 1) * wg_seq_len + past_q_lens; + if (kv_stop > kv_seq_len) kv_stop = kv_seq_len; + } + + //Q/O[B, L, H, S] + uint q_offset = (q_start_sg*num_heads + h)*head_size; + +#if SPARSE_BLOCK_SIZE > 1 + bool *block_mask_base, *wg_block_mask_base; + if (validate) { + //# sparse_block_mask [num_heads, num_q_blocks, num_k_blocks] + //# sparse_block_mask_wg [num_heads, wg_count_along_query, num_k_blocks] + auto q_start_block = q_start_sg/ SPARSE_BLOCK_SIZE; + block_mask_base = sparse_block_mask + (h * num_q_blocks + q_start_block) * num_k_blocks; + wg_block_mask_base = sparse_block_mask_wg + (h * cm_group_count(2) + wg_id) * num_k_blocks; + } + #endif + +#if CMPA_KVCACHE_U8 + uint kv_offset = hkv*(head_size+4)*pa_block_sz; + pa_lsc_u8( + slm_K, + slm_V, + wg_local_id, + local_size, + q_start_sg, //q_start for SG, + kv_stop, + q_len_sg, //q_step, + kv_seq_len, //kv_len, + reinterpret_cast(query + q_offset), + reinterpret_cast(k_cache + kv_offset), + reinterpret_cast(v_cache + kv_offset), +#if SPARSE_BLOCK_SIZE > 1 + reinterpret_cast(block_mask_base), + reinterpret_cast(wg_block_mask_base), + validate, +#endif + reinterpret_cast(output + q_offset), + past_q_lens, + block_indices); +#else + uint kv_offset = hkv*head_size*pa_block_sz; + pa_kernel_lsc_prefetch_f16( + wg_local_id, + q_start_sg, //q_start for SG, + kv_stop, + q_len_sg, //q_step, + kv_seq_len, //kv_len, + reinterpret_cast(query + q_offset), + reinterpret_cast(k_cache + kv_offset), + reinterpret_cast(v_cache + kv_offset), +#if SPARSE_BLOCK_SIZE > 1 + reinterpret_cast(block_mask_base), + reinterpret_cast(wg_block_mask_base), + validate, +#endif + reinterpret_cast(output + q_offset), + past_q_lens, + block_indices); +#endif +} +} // namespace KERNEL_NAME \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm new file mode 100644 index 00000000000000..5db4e9a5dc3b8e --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm @@ -0,0 +1,456 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *******************************************************************************/ + +// xe-1:8, xe-2:16 +#if XE_ARCH==1 +#define REG_N 8 +#define USE_LSC_BLOCK_2D_DESC 0 +#else +#define REG_N 16 +#define USE_LSC_BLOCK_2D_DESC 1 +#endif + +#define SystolicDepth 8 +#define RepeatCount 1 +#define VNNI_WIDTH 2 +#define REG_K (SystolicDepth * VNNI_WIDTH) +#define REG_M RepeatCount +#define KV_PARTITION_STEP_NUM (KV_PARTITION_SIZE / KV_STEP) + + +#define Q_SLICE_NUM (HEADS_NUM / KV_HEADS_NUM) +#if Q_SLICE_NUM > 8 || Q_SLICE_NUM == 1 +#define Q_RepeatCount 1 +#else +#define Q_RepeatCount Q_SLICE_NUM +#endif + +#if KV_CACHE_COMPRESSION + // scale/zp is half-precision, so size = 2 * 2 = 4 bytes + #define KV_SCALE_ZP_SIZE 4 // scale/zp bytes + #define KV_ELEMENT_TYPE uint8_t +#else + #define KV_SCALE_ZP_SIZE 0 // no scale/zp + #define KV_ELEMENT_TYPE half +#endif + +// prepack [K, N] to [K/2, N, 2] layout. +template +inline void prepack_to_VNNI_W2(matrix_ref input, matrix_ref out) { + #pragma unroll + for (int r = 0; r < K/2; r++) { + out.row(r).select(0) = input.row(r*2); + out.row(r).select(1) = input.row(r*2+1); + } +} + +extern "C" _GENX_MAIN_ void KERNEL_NAME( + half* query [[type("svmptr_t")]], + KV_ELEMENT_TYPE* key [[type("svmptr_t")]], + KV_ELEMENT_TYPE* value [[type("svmptr_t")]], + int* past_lens [[type("svmptr_t")]], + int* block_indices [[type("svmptr_t")]], + int* block_indices_begins [[type("svmptr_t")]], + int* subsequence_begins [[type("svmptr_t")]], + float* output [[type("svmptr_t")]], + float* lse [[type("svmptr_t")]], + int q_len // 1 + ) { + //# batch=1, seq_num=1 or >1 + //# query [seq_idx, seq_num, head_num, head_size] + //# output[seq_idx, seq_num, head_num, head_size] + //# key [block_num, head_num, block_size, head_size] + [block_num, head_num, block_size, 4] (scale/zp) + //# value [block_num, head_num, block_size, head_size] + [block_num, head_num, block_size, 4] (scale/zp) + + //# KV_PARTITION_SIZE should be a multiple of kv_block_size (KV_BLOCK_SIZE) + //# the kv_len dimension will be split into multiple partitions; each WG processes a partition + //# total_partitions_num = kv_len // KV_PARTITION_SIZE + //# GWS=[seq_num, num_heads, total_partitions_num] + //# LWS=[1, 1, 1] + + //# Each WG processes a partition, which is KV_PARTITION_SIZE long and a multiple of KV_BLOCK_SIZE. + //# KV_BLOCK_SIZE can be 32/64/128/256, etc. 
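+ //# Editor's illustrative sketch (assumes the ceil_div semantics used by the host-side dispatcher):
+ //#   total_partitions_num = (kv_len + KV_PARTITION_SIZE - 1) / KV_PARTITION_SIZE;
+ //#   e.g. kv_len = 1000, KV_PARTITION_SIZE = 256 -> 4 partitions; the last partition holds
+ //#   1000 - 3 * 256 = 232 tokens and is handled through leftover_size below.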
const auto seq_idx = cm_global_id(0); + const auto kv_head_num_idx = cm_global_id(1); + const auto head_num_idx = kv_head_num_idx * (HEADS_NUM/KV_HEADS_NUM); + //# KV_PARTITION_SIZE --> EU thread + const auto wg_thread_id = cm_global_id(2); + const uint kv_partition_num = cm_group_count(2); + const uint kv_partition_idx = cm_group_id(2); + + const uint kv_len = past_lens[seq_idx] + 1; + const uint start_block_idx = block_indices_begins[seq_idx] + kv_partition_idx * (KV_PARTITION_SIZE / KV_BLOCK_SIZE); + + if(kv_partition_idx * KV_PARTITION_SIZE > kv_len) { + return; + } + const uint total_blocks_num = (kv_len + KV_BLOCK_SIZE - 1) / KV_BLOCK_SIZE; + constexpr uint kv_pitch = HEAD_SIZE * sizeof(KV_ELEMENT_TYPE); + + //# Load Q into register (as dpas-A tile) + const uint qo_offset = (seq_idx*HEADS_NUM*q_len + head_num_idx)*HEAD_SIZE; + + #if Q_RepeatCount != 1 + matrix Qmat = 0; + cm_svm_block_read((svmptr_t)(query + qo_offset), Qmat.format()); + #else + matrix Qmat = 0; + cm_svm_block_read((svmptr_t)(query + qo_offset), Qmat.format()); + #endif + + constexpr uint per_kv_block_element_num = KV_BLOCK_SIZE * KV_HEADS_NUM * (HEAD_SIZE + KV_SCALE_ZP_SIZE / sizeof(KV_ELEMENT_TYPE)); // 4 bytes: scale/zp + uint block_num = KV_PARTITION_SIZE / KV_BLOCK_SIZE; + + uint leftover_size = 0; + if(kv_partition_idx == kv_partition_num - 1) { + leftover_size = (kv_len - KV_PARTITION_SIZE * kv_partition_idx) % KV_PARTITION_SIZE; + } + if(block_num > total_blocks_num - start_block_idx) { + block_num = total_blocks_num - start_block_idx; + } + + //# rS = Q @ Kt + #if Q_RepeatCount != 1 + matrix rS = 0; + #else + matrix rS = 0; + #endif + //# Each SG can process multiple blocks + #pragma unroll + for(uint block_idx = 0, ki = 0; block_idx < block_num; block_idx++) { + uint blk_indices = block_indices[start_block_idx + block_idx]; + uint kv_base_offset = blk_indices * per_kv_block_element_num + kv_head_num_idx * (per_kv_block_element_num / KV_HEADS_NUM); + uint kv_scale_zp_offset = kv_base_offset + KV_BLOCK_SIZE * HEAD_SIZE; // scale/zp offset + + #if USE_LSC_BLOCK_2D_DESC + #if KV_CACHE_COMPRESSION + // Transposed 2D loads only support dword and qword + lsc::block_2d_desc b2dK(reinterpret_cast(key + kv_base_offset), KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + #else + lsc::block_2d_desc b2dK(reinterpret_cast(key + kv_base_offset), KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(half) - 1, kv_pitch - 1, 0, 0); + #endif + #else + uint kv_offset = kv_base_offset; + uint kv_stride = HEAD_SIZE; + uint kv_x0 = 0, kv_y0 = 0; + uint kv_x1 = HEAD_SIZE*sizeof(KV_ELEMENT_TYPE); + uint kv_y1 = KV_BLOCK_SIZE; + #endif + + uint kv_pos_end = KV_BLOCK_SIZE; + if(block_idx == block_num - 1 && leftover_size > 0) { + kv_pos_end = leftover_size % KV_BLOCK_SIZE; + if(kv_pos_end == 0) kv_pos_end = KV_BLOCK_SIZE; + } + + #if KV_CACHE_COMPRESSION + // load scale/zp + vector scale_vec; + vector zp_vec; + cm_svm_block_read(reinterpret_cast(key + kv_scale_zp_offset), scale_vec); + cm_svm_block_read(reinterpret_cast(key + kv_scale_zp_offset + KV_BLOCK_SIZE * sizeof(half)), zp_vec); + if(kv_pos_end < KV_BLOCK_SIZE) { + // zero the leftover scale/zp entries so padded tokens dequantize to 0 + #pragma unroll + for(int i = kv_pos_end; i < KV_BLOCK_SIZE; i++) { + scale_vec[i] = 0.0; + zp_vec[i] = 0.0; + } + } + #endif + + for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += KV_STEP, ki++) { + #if KV_CACHE_COMPRESSION + vector temp_scale, temp_zp; + temp_scale.select(0) = scale_vec.select(kv_pos); + temp_scale.select(1) = scale_vec.select(kv_pos); + temp_zp.select(0) = 
zp_vec.select(kv_pos); + temp_zp.select(1) = zp_vec.select(kv_pos); + #endif + + #pragma unroll + #if KV_CACHE_COMPRESSION + for(int k = 0, ri = 0; k < HEAD_SIZE/4; k += REG_K/4, ri ++ ) { + #else + for(int k = 0, ri = 0; k < HEAD_SIZE/2; k += REG_K/2, ri ++ ) { + #endif + matrix Kt = 0; + #if USE_LSC_BLOCK_2D_DESC + //# Load Kt into register & pack as VNNI (as dpas-B tile) + //# DWORD transposed load == (transposed + VNNI) load + b2dK.set_block_x(k); + + #if KV_CACHE_COMPRESSION + // dequantize + matrix Kt_quant_temp, Kt_quant; + cm_load(Kt_quant_temp.format(), b2dK.set_block_y(kv_pos)); + auto quant_src = Kt_quant_temp.format(); + auto quant_dst = Kt_quant.format(); + + #pragma unroll + for(int r = 0; r < REG_K / 2; r += 2) { + quant_dst.row(r ) = quant_src.select<2,1,8,2>(r,0); + quant_dst.row(r+1) = quant_src.select<2,1,8,2>(r,1); + } + + #pragma unroll + for(int r = 0; r < REG_K; r++) { + Kt[r] = Kt_quant[r] - temp_zp.format()[r%2]; // vector - vector + Kt[r] = cm_mul(Kt[r], temp_scale.format()[r%2]); // vector * vector + } + #else + cm_load(Kt.format(), b2dK.set_block_y(kv_pos)); + #endif + #else + matrix temp; + uint cur_kv_offset = kv_offset + kv_pos * kv_stride + k * 2; // uint --> half + #pragma unroll + for(int kk = 0; kk < REG_N; kk++) { + cm_svm_block_read((svmptr_t)(key + cur_kv_offset + kk * kv_stride), temp[kk].format()); + } + #if XE_ARCH==1 + Transpose_8x8(temp.select<8,1,8,1>(0,0), Kt.format().select<8,1,8,1>(0,0)); + #else + Transpose_8x8(temp.select<8,1,8,1>(0,0), Kt.format().select<8,1,8,1>(0,0)); + Transpose_8x8(temp.select<8,1,8,1>(8,0), Kt.format().select<8,1,8,1>(0,8)); + #endif + #endif + #if Q_RepeatCount != 1 + matrix Qmat_data = Qmat.select(0, ri*REG_K); + matrix rS_data = 0; + for(int qi = 0; qi < Q_SLICE_NUM; qi ++) { + Qmat_data[qi] = Qmat[qi].format()[ri]; + } + rS_data = cm_dpas( + rS_data.format(), + Kt.format(), + Qmat_data.format()); + rS.select(0, ki*REG_N) += rS_data; + #else + #pragma unroll + for(int qi = 0; qi < Q_SLICE_NUM; qi ++) { + auto Qmat_slice = Qmat[qi].format(); + auto rSvec = rS[qi].format()[ki].format(); + rSvec = cm_dpas( + rSvec, + Kt.format(), + Qmat_slice[ri].format()); + } + #endif + } + } + } + + // online softmax + vector cur_sum = 0.0f; + vector cur_lse = 0.0f; + #if XE_ARCH==1 + matrix Pmat = 0; + #else + #if Q_RepeatCount != 1 + matrix Pmat = 0; + #else + matrix Pmat = 0; + #endif + #endif + #pragma unroll + for(int qi = 0; qi < Q_SLICE_NUM; qi++) { + auto rS_slice = rS[qi].format(); + rS_slice = cm_mul(rS_slice, (float)SCALE_FACTOR); // convert scale_factor into (float), or it will be promoted to double + + if(leftover_size > 0) { + auto Svec = rS_slice.format(); + for(int i = leftover_size; i < KV_PARTITION_STEP_NUM * REG_N; i++){ + Svec[i] = -3e38f; + } + } + + // compute lse + constexpr float log2e = 1.4426950408889634f; + constexpr float loge2 = 0.6931471805599453f; + vector rS_exp = cm_exp(rS_slice.format()*log2e); + + // compute row_max + auto rSv = rS_slice.format(); + float row_max = rSv[0]; + // This loop is a performance hotspot for u8; the unroll is required + #if KV_CACHE_COMPRESSION + #pragma unroll + #endif + for(int r = 1; r < rSv.n_elems(); r++) + row_max = cm_max(row_max, rSv[r]); + + // compute P = exp(rS_slice - row_max) + #if XE_ARCH==1 + Pmat[qi].format() = cm_exp((rS_slice.format() - row_max)*log2e); + #else + Pmat[qi].format() = cm_exp((rS_slice - row_max)*log2e); + #endif + + vector rS_exp_temp = cm_exp((rS_slice.format() - row_max)*log2e); + cur_lse[qi] = cm_sum(rS_exp_temp.format()); + cur_lse[qi] = 
cm_log(cur_lse[qi]) * loge2 + row_max; // lse = ln(sum(exp(x))) = log2(sum(exp(x))) * ln(2); cm_log is base-2 + + // compute row sum of P + auto rPv = Pmat[qi].format(); + cur_sum[qi] = cm_sum(rPv[0]); + } + + //# rO = P * V + #if Q_RepeatCount != 1 + matrix Omat = 0; + #else + matrix Omat = 0; + #endif + #pragma unroll + for(uint block_idx = 0, ki = 0; block_idx < block_num; block_idx++) { + uint blk_indices = block_indices[start_block_idx + block_idx]; + uint kv_base_offset = blk_indices * per_kv_block_element_num + kv_head_num_idx * (per_kv_block_element_num / KV_HEADS_NUM); + uint kv_scale_zp_offset = kv_base_offset + KV_BLOCK_SIZE * HEAD_SIZE; // scale/zp offset + + #if USE_LSC_BLOCK_2D_DESC + //# vector load cannot be used for block_2d_desc + //# note: candidate template ignored: deduced type 'details::Block2DRefTy' (aka 'vector_ref') of 1st parameter + #if KV_CACHE_COMPRESSION + lsc::block_2d_desc b2dV(value + kv_base_offset, KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); + #else + lsc::block_2d_desc b2dV(value + kv_base_offset, KV_BLOCK_SIZE - 1, HEAD_SIZE*sizeof(half) - 1, kv_pitch - 1, 0, 0); + #endif + #else + uint kv_offset = kv_base_offset; + uint kv_stride = HEAD_SIZE; + uint kv_x0 = 0, kv_y0 = 0; + uint kv_x1 = HEAD_SIZE*sizeof(half); + uint kv_y1 = KV_BLOCK_SIZE; + #endif + + uint kv_pos_end = KV_BLOCK_SIZE; + if(block_idx == block_num - 1 && leftover_size > 0) { + kv_pos_end = leftover_size % KV_BLOCK_SIZE; + if(kv_pos_end == 0) kv_pos_end = KV_BLOCK_SIZE; + } + + #if KV_CACHE_COMPRESSION + // load scale/zp + vector scale_vec; + vector zp_vec; + cm_svm_block_read(reinterpret_cast(value + kv_scale_zp_offset), scale_vec); + cm_svm_block_read(reinterpret_cast(value + kv_scale_zp_offset + KV_BLOCK_SIZE * sizeof(half)), zp_vec); + if(kv_pos_end < KV_BLOCK_SIZE) { + // zero the leftover scale/zp entries so padded tokens dequantize to 0 + for(int i = kv_pos_end; i < KV_BLOCK_SIZE; i++) { + scale_vec[i] = 0.0; + zp_vec[i] = 0.0; + } + } + #endif + #pragma unroll + for(int kv_pos = 0; kv_pos < kv_pos_end; kv_pos += REG_K, ki++) { + #if KV_CACHE_COMPRESSION + vector temp_scale = scale_vec.select(kv_pos); + vector temp_zp = zp_vec.select(kv_pos); + #endif + #pragma unroll + for(int k = 0, ri = 0; k < HEAD_SIZE; k += REG_N, ri ++ ) { + // Load V into register & pack as VNNI (as dpas-B tile) + matrix VmatNormal; + matrix Vmat; + #if USE_LSC_BLOCK_2D_DESC + b2dV.set_block_x(k); + #if KV_CACHE_COMPRESSION + // dequantize + matrix Vt_quant; + cm_load(Vt_quant.format(), b2dV.set_block_y(kv_pos)); + + #pragma unroll + for(int r = 0; r < REG_K; r++) { + VmatNormal[r] = Vt_quant[r] - temp_zp[r]; // vector - scalar + VmatNormal[r] = cm_mul(VmatNormal[r], temp_scale[r]); // vector * scalar + } + + if(kv_pos_end - kv_pos < KV_STEP) { + #pragma unroll + for(int r = kv_pos_end; r()); + #else + cm_load(Vmat[0].format(), b2dV.set_block_y(kv_pos)); + // The KV cache can contain random NaNs (observed on PTL), so the unused value rows must be cleaned up. 
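+ // In VNNI layout two consecutive V rows are interleaved into each packed register row, so an odd number of valid rows leaves half of the last packed row stale; the cleanup below zeroes the fully unused packed rows first, then the odd half-row.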
+ if(kv_pos_end - kv_pos < KV_STEP) { + auto VmatRef = Vmat[0].format(); + uint valid_rows = kv_pos_end - kv_pos; + uint valid_rows_vnni = (valid_rows+1)/2; + for (int r = valid_rows_vnni; r < KV_STEP / 2; r++) + VmatRef.row(r) = 0.f; + if (valid_rows % 2 == 1) + VmatRef.row(valid_rows_vnni-1).select(1) = 0.f; + } + #endif + #else + matrix temp; + uint cur_kv_offset = kv_offset + kv_pos * kv_stride + k; + #pragma unroll + for(int kk = 0; kk < REG_K; kk++) { + cm_svm_block_read((svmptr_t)(value + cur_kv_offset + kk * kv_stride), temp[kk].format()); + } + auto Vref = Vmat[0].format(); + Vref.select(0, 0) = temp.select(0, 0); + Vref.select(0, 1) = temp.select(1, 0); + #endif + #if Q_RepeatCount != 1 + matrix Pmat_data = Pmat.select(0, ki*REG_K); + matrix Omat_data = 0; + Omat_data = cm_dpas( + Omat_data.format(), + Vmat[0].format(), + Pmat_data.format()); + Omat.select(0, ri*REG_N) += Omat_data; + #else + for(int qi = 0; qi < Q_SLICE_NUM; qi ++) { + auto Pmat_slice = Pmat[qi].format(); + auto Omat_slice = Omat[qi].format(); + Omat_slice[ri] = cm_dpas( + Omat_slice[ri], + Vmat[0].format(), + Pmat_slice[ki].format()); + } + #endif + } + } + } + + //# save Output + for (int qi = 0; qi < Q_SLICE_NUM; qi++) { + matrix cur_O_f32; + uint o_offset = seq_idx * kv_partition_num * KV_HEADS_NUM * HEAD_SIZE + kv_partition_num * (head_num_idx + qi) * HEAD_SIZE + wg_thread_id * HEAD_SIZE; + float div_cur_sum = 1.0/cur_sum[qi]; + auto Omat_slice = Omat[qi].format(); + #pragma unroll + for(int k = 0, ri=0; k < HEAD_SIZE; k += REG_N, ri++) { + auto cO = Omat_slice[ri].format(); + #if XE_ARCH==1 + cur_O_f32= cm_mul(cO, div_cur_sum); + #else + cur_O_f32= cm_div_ieee(cO, cur_sum[qi]); + #endif + cm_svm_block_write((svmptr_t)(output + o_offset + k),cur_O_f32.format()); + } + uint lse_offset = seq_idx * KV_HEADS_NUM * kv_partition_num + (head_num_idx + qi) * kv_partition_num + wg_thread_id; + lse[lse_offset] = cur_lse[qi]; + } +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm new file mode 100644 index 00000000000000..501ffe48e3ec69 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token_finalization.cm @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *******************************************************************************/ + +//cm_sdpa_2nd_reduce +extern "C" _GENX_MAIN_ void KERNEL_NAME( + float* input [[type("svmptr_t")]], + half* output [[type("svmptr_t")]], + float* lse [[type("svmptr_t")]], + int kv_partition_num + ) { + auto batch = cm_global_id(0); + auto head = cm_global_id(1); + auto offset = cm_group_id(2) * REDUCE_SPLIT_SIZE; + const int total_partition_num = (kv_partition_num * HEADS_NUM); + + // load lse + float total_lse = 0.0; + uint lse_offset = batch * total_partition_num + head * kv_partition_num; + constexpr float log2e = 1.4426950408889634f; + float* lse_vec = lse + lse_offset; + float lse_max = lse_vec[0]; + for(int k = 1; k < kv_partition_num; k ++) { + lse_max = cm_max(lse_vec[k], lse_max); + } + + int iter = kv_partition_num / 16; + for(int k = 0; k < iter * 16; k += 16) { + #pragma unroll + for(int ki = k; ki < k + 16; ki ++) { + float lse_value = cm_exp((lse_vec[ki] - lse_max)*log2e); + total_lse += lse_value; + } + } + for(int k = iter * 16; k < kv_partition_num; k ++) { + float lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); + total_lse += lse_value; + } + + // load input, total_partition_num = head_nums * kv_partition_num; + matrix out_mat_f32 = 0; + matrix out_mat = 0; + matrix data_mat; + uint input_offset = batch * total_partition_num * HEAD_SIZE + head * kv_partition_num * HEAD_SIZE + offset; + + for(int k = 0; k < iter * 16; k += 16) { + #pragma unroll + for(int ki = k; ki < k + 16; ki ++) { + cm_svm_block_read((svmptr_t)(input + input_offset), data_mat.format()); + input_offset += HEAD_SIZE; + float lse_value = cm_exp((lse_vec[ki] - lse_max)*log2e); + out_mat_f32 += cm_mul(data_mat, (float)(lse_value/total_lse)); + } + } + for(int k = iter * 16; k < kv_partition_num; k ++) { + cm_svm_block_read((svmptr_t)(input + input_offset), data_mat.format()); + input_offset += HEAD_SIZE; + float lse_value = cm_exp((lse_vec[k] - lse_max)*log2e); + out_mat_f32 += cm_mul(data_mat, (float)(lse_value/total_lse)); + } + out_mat = out_mat_f32.format(); + + // write output + uint output_offset = batch * HEADS_NUM * HEAD_SIZE + head * HEAD_SIZE + offset; + cm_svm_block_write((svmptr_t)(output + output_offset),out_mat.format()); + } \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp new file mode 100644 index 00000000000000..f2ddcfd67de1f6 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp @@ -0,0 +1,226 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "paged_attention.hpp" + +#include +#include +#include +#include + +#include "common_utils/jitter.hpp" +#include "common_utils/kernel_generator_base.hpp" +#include "intel_gpu/graph/kernel_impl_params.hpp" +#include "intel_gpu/primitives/paged_attention.hpp" +#include "kv_cache_inst.h" +#include "openvino/core/partial_shape.hpp" +#include "paged_attention_gen.hpp" +#include "paged_attention_inst.h" +#include "primitive_cm_base.hpp" +#include "primitive_inst.h" + +namespace ov::intel_gpu::cm { + +class PagedAttentionCmImpl : public PrimitiveImplCM { +public: + DECLARE_OBJECT_TYPE_SERIALIZATION(ov::intel_gpu::cm::PagedAttentionCmImpl) + + Stage::Ptr kv_cache_update = make_stage(); + Stage::Ptr pa_single_token = make_stage(); + Stage::Ptr pa_single_token_finalization = make_stage(); + Stage::Ptr pa_multi_token = make_stage(); + Stage::Ptr xattn_estimate_gemmqk = make_stage(); + 
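// Together with xattn_estimate_gemmqk above, the two stages below complete the XAttention mask-estimation pipeline: strided Q*K' partial softmax sums -> per-block threshold selection -> mask post-processing/merge.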
Stage::Ptr xattn_estimate_find_block = make_stage(); + Stage::Ptr xattn_estimate_post_proc = make_stage(); + + PagedAttentionCmImpl() : PrimitiveImplCM(PagedAttentionImplementationManager::get_type_info_static()) { + m_rt_params = std::make_unique(); + } + explicit PagedAttentionCmImpl(const kernel_impl_params& params) : PagedAttentionCmImpl() { + const auto desc = params.typed_desc(); + + GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::PagedAttentionCmImpl()" << std::endl; + add_stage(kv_cache_update, params); + add_stage(pa_single_token, params); + add_stage(pa_single_token_finalization, params); + add_stage(pa_multi_token, params); + if (desc->has_xattention) { + add_stage(xattn_estimate_gemmqk, params); + add_stage(xattn_estimate_find_block, params); + add_stage(xattn_estimate_post_proc, params); + } + } + + void update_xattn_rt_params(const kernel_impl_params& params) { + const auto desc = params.typed_desc(); + + // The XAttention estimate stages run right after kv_cache_update. + auto out_shape = params.output_layouts[0].get_shape(); + const size_t block_size = get_xattn_block_size(params); + const size_t kv_len = get_max_context_len(params); + const size_t q_len = out_shape[0]; + const size_t N = kv_len / STRIDE; + const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N); + + const auto q_block_pad = ceil_div(q_len, block_size); + const auto sum_per_token_in_block = block_size / STRIDE; + const auto k_block_in_group = BLOCK_WG_N / sum_per_token_in_block; + const auto k_block_pad = k_block_in_group * N_kq_groups; + + auto rt_params = static_cast(m_rt_params.get()); + rt_params->q_block_pad = q_block_pad; + rt_params->k_block_pad = k_block_pad; + rt_params->q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); + + const size_t head_size = desc->k_head_size; + + const auto M = q_len / STRIDE; //# silently drops the tail that is shorter than `stride` + const auto K = STRIDE * head_size; + + const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M); + + rt_params->N_kq_groups = N_kq_groups; + rt_params->M = M; + rt_params->N = N; + rt_params->K = K; + rt_params->q_stride_pad = q_stride_pad; + } + + void update_rt_params(const primitive_inst& instance) override { + update_stages_flags(instance); + if (m_rt_params == nullptr) { + m_rt_params = std::make_unique(); + } + + const auto& params = *instance.get_impl_params(); + OPENVINO_ASSERT(!params.is_dynamic()); + auto rt_params = static_cast(m_rt_params.get()); + const auto& desc = params.typed_desc(); + + rt_params->stage = get_paged_attention_stage(params); + const auto max_context_len = get_max_context_len(params); + rt_params->max_context_len = max_context_len; + GPU_DEBUG_TRACE_DETAIL << "update_rt_params for stage: " << static_cast(rt_params->stage) << " max_context_len: " << rt_params->max_context_len + << std::endl; + + if (rt_params->stage == PagedAttentionStage::GENERATE) { + auto partition_size = get_partition_size(desc->has_xattention); + rt_params->num_of_partitions = ceil_div(max_context_len, partition_size); + + GPU_DEBUG_TRACE_DETAIL << " partition_size: " << partition_size << " num_of_partitions: " << rt_params->num_of_partitions << std::endl; + } else { + if (desc->has_xattention) { + update_xattn_rt_params(params); + } + } + } + + // update impl params and runtime params + void update(primitive_inst& inst, const kernel_impl_params& impl_params) override { + PrimitiveImplCM::update(inst, impl_params); + update_rt_params(inst); + } + + event::ptr execute(const std::vector& events, primitive_inst& instance) override 
{ + const auto& params = *instance.get_impl_params(); + const auto desc = params.typed_desc(); + + update_stages_flags(instance); + auto rt_params = static_cast(m_rt_params.get()); + OPENVINO_ASSERT(rt_params != nullptr); + + GPU_DEBUG_TRACE_DETAIL << "ov::intel_gpu::cm::PagedAttentionCmImpl::execute(): stage = " << static_cast(rt_params->stage) << std::endl; + std::vector res_event = events; + res_event = {execute_stage(res_event, instance, kv_cache_update)}; + + if (rt_params->stage == PagedAttentionStage::PREFILL || rt_params->stage == PagedAttentionStage::MIXED) { + if (has_stage(xattn_estimate_gemmqk) && !bypass_xattn(params)) { + res_event = {execute_stage(res_event, instance, xattn_estimate_gemmqk)}; + res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)}; + res_event = {execute_stage(res_event, instance, xattn_estimate_post_proc)}; + } + res_event = {execute_stage(res_event, instance, pa_multi_token)}; + } else if (rt_params->stage == PagedAttentionStage::GENERATE) { + res_event = {execute_stage(res_event, instance, pa_single_token)}; + res_event = {execute_stage(res_event, instance, pa_single_token_finalization)}; + } + return res_event[0]; + } + + bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const override { + const auto stage = get_paged_attention_stage(impl_params); + + // In case of MIXED mode execution Paged Attention may require dispatch data update and internal + // buffers reallocation even if the input shapes haven't been changed. Therefore, check the current execution + // mode and update parameters if needed + return stage == PagedAttentionStage::MIXED; + } + + [[nodiscard]] std::vector get_internal_buffer_descs(const kernel_impl_params& params) const override { + std::vector internal_buffers; + + const auto desc = params.typed_desc(); + auto rt_params = static_cast(m_rt_params.get()); + // Assume rt_params are updated, because get_internal_buffer_descs surely occurs after update_rt_params. 
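+ // Illustrative sizing for the GENERATE branch below (hypothetical numbers): with total_tokens = 1, heads_num = 32, v_head_size = 128 and num_of_partitions = 8, exp_sums takes 1 * 32 * 8 = 256 floats and tmp_out takes 1 * 32 * 128 * 8 = 32768 floats.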
+ OPENVINO_ASSERT(rt_params != nullptr); + + const auto stage = rt_params->stage; + GPU_DEBUG_TRACE_DETAIL << " stage = " << static_cast(stage) << std::endl; + if (stage == PagedAttentionStage::GENERATE) { + OPENVINO_ASSERT(rt_params->num_of_partitions != 0); + size_t num_of_partitions = rt_params->num_of_partitions; + + const auto& input = params.input_layouts[0]; + const int64_t total_tokens = input.get_partial_shape()[0].get_length(); + auto buf_elements_count = static_cast(total_tokens * desc->heads_num * num_of_partitions); + auto tmp_out_elements_count = static_cast(total_tokens * desc->heads_num * desc->v_head_size * num_of_partitions); + + internal_buffers.emplace_back(tmp_out_elements_count, ov::element::f32); // 0: intermediate partition output + internal_buffers.emplace_back(buf_elements_count, ov::element::f32); // 1: softmax exp_sums + + GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: tmp_out=" << tmp_out_elements_count * 4 << " exp_sums=" << buf_elements_count * 4 << std::endl; + } else { + internal_buffers.emplace_back(16, ov::element::f32); // 0: intermediate partition output + internal_buffers.emplace_back(16, ov::element::f32); // 1: softmax exp_sums + + // internal buffer for XAttention + if (desc->has_xattention) { + auto count_kq_max_wg = static_cast(desc->heads_num * rt_params->N_kq_groups * rt_params->q_stride_pad); + internal_buffers.emplace_back(count_kq_max_wg, ov::element::f32); // 2: kq_max_wg + + auto count_kq_exp_partial_sum = static_cast(desc->heads_num * rt_params->q_stride_pad * rt_params->k_block_pad); + internal_buffers.emplace_back(count_kq_exp_partial_sum, ov::element::f32); // 3: kq_exp_partial_sum + + auto count_elements_mask = static_cast(desc->heads_num * rt_params->q_block_pad * rt_params->k_block_pad); + internal_buffers.emplace_back(count_elements_mask, ov::element::boolean); // 4: sparse_block_mask + + auto count_elements_mask_merged = static_cast(desc->heads_num * rt_params->q_block_pad_merged * rt_params->k_block_pad); + internal_buffers.emplace_back(count_elements_mask_merged, ov::element::boolean); // 5: sparse_block_mask_wg + + GPU_DEBUG_TRACE_DETAIL << " internal buffer sizes: count_kq_max_wg=" << count_kq_max_wg * 4 + << " count_kq_exp_partial_sum=" << count_kq_exp_partial_sum * 4 << " count_elements_mask=" << count_elements_mask * 1 + << " count_elements_mask_merged=" << count_elements_mask_merged * 1 << std::endl; + } + } + + return internal_buffers; + } + + [[nodiscard]] std::unique_ptr clone() const override { + return make_deep_copy(this); + } +}; + +std::unique_ptr PagedAttentionImplementationManager::create_impl(const program_node& node, const kernel_impl_params& params) const { + OPENVINO_ASSERT(node.is_type()); + try { + return std::make_unique(params); + } catch (const std::exception& e) { + OPENVINO_THROW("Failed to create PagedAttentionCmImpl: ", e.what()); + } +} + +} // namespace ov::intel_gpu::cm +// BIND_BINARY_BUFFER_WITH_TYPE(cldnn::paged_attention) +BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::cm::PagedAttentionCmImpl) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp new file mode 100644 index 00000000000000..7d5c5693691a3e --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp @@ -0,0 +1,86 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "intel_gpu/runtime/layout.hpp" +#include "paged_attention_inst.h" +#include 
"program_node.h" +#include "registry/implementation_manager.hpp" + +using namespace cldnn; // TODO: Remove once namespaces are aligned + +namespace ov::intel_gpu::cm { + +struct PagedAttentionImplementationManager : public ImplementationManager { + OV_GPU_PRIMITIVE_IMPL("cm::paged_attention::opt") + explicit PagedAttentionImplementationManager(shape_types shape_type, ValidateFunc vf = nullptr) + : ImplementationManager(impl_types::cm, shape_type, std::move(vf)) {} + [[nodiscard]] std::unique_ptr create_impl(const program_node& node, const kernel_impl_params& params) const override; + [[nodiscard]] bool validate_impl(const program_node& node) const override { + static constexpr std::array supported_q_types = { + ov::element::f16, + }; + static constexpr std::array supported_kv_types = { + ov::element::f16, + ov::element::i8, + }; + + // Enable CM PA only in case of XAttention been enabled. May decouple them in future. + auto desc = node.as().get_primitive(); + if (!desc->has_xattention) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false because we enable CM PA when XAttention is enabled. " << std::endl; + return false; + } + + // TODO: Remove this limitation when PA CM kernel supports more "heads_num / kv_heads_num" cases. + // PA 2nd token CM kernel only supports case of "heads_num / kv_heads_num <= 8" + if (desc->heads_num / desc->kv_heads_num > 8) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false because heads_num / kv_heads_num > 8. " << std::endl; + return false; + } + + auto& engine = node.get_program().get_engine(); + const auto& config = node.get_program().get_config(); + const auto& info = engine.get_device_info(); + // CM optimized for systolic-array architectures + if (!check_cm_jit_support(engine, config) || !info.supports_immad || !config.get_use_cm()) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported GPU architecture. " << std::endl; + return false; + } + + const auto& q_layout = node.get_input_layout(PagedAttentionInputIdx::QUERY); + const auto& k_layout = node.get_input_layout(PagedAttentionInputIdx::KEY); + const auto& v_layout = node.get_input_layout(PagedAttentionInputIdx::VALUE); + const auto& out_layout = node.get_output_layout(0); + if (!everyone_is(format::bfyx, q_layout.format, k_layout.format, v_layout.format, out_layout.format)) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported qkv layout. " << std::endl; + return false; + } + + if (!one_of(k_layout.data_type, supported_q_types) || !one_of(v_layout.data_type, supported_q_types)) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported kv data type. " << std::endl; + return false; + } + + if (!one_of(q_layout.data_type, supported_q_types) || !one_of(out_layout.data_type, supported_q_types)) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported q/out data type. " << std::endl; + return false; + } + + const auto& kcache_layout = node.get_input_layout(PagedAttentionInputIdx::KEY_CACHE); + const auto& vcache_layout = node.get_input_layout(PagedAttentionInputIdx::VALUE_CACHE); + if (!one_of(kcache_layout.data_type, supported_kv_types) || !one_of(vcache_layout.data_type, supported_kv_types)) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false due to unsupported kv cache data type. 
" << std::endl; + return false; + } + + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - true" << std::endl; + return true; + } +}; +} // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp new file mode 100644 index 00000000000000..ffcea9c6f51eba --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp @@ -0,0 +1,798 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "paged_attention_gen.hpp" + +#include +#include +#include +#include + +#include "intel_gpu/primitives/paged_attention.hpp" +#include "intel_gpu/runtime/memory.hpp" +#include "openvino/core/partial_shape.hpp" +#include "paged_attention_inst.h" +#include "primitive_cm_base.hpp" +#include "primitive_inst.h" + +namespace ov::intel_gpu::cm { + +using namespace ov; +using namespace ov::intel_gpu::ocl; +using namespace cldnn; +namespace { +constexpr size_t WG_SIZE = 16; +constexpr size_t reduce_split_step = 16; +} // namespace + +#define DEBUG_ENABLED 0 + +// This function returns the kv_step and kv_split_len based on the architecture. +// return {kv_step, kv_split_len} +inline std::pair get_kv_split_size(size_t arch) { + if (arch == 1) { + return {8, 32}; // For Xe1 + } else if (arch == 2) { + return {16, 32}; // For Xe2 + } + OPENVINO_ASSERT(false, "Unsupported architecture for KV split size"); + return {0, 0}; // Fallback case, should not be reached +} + +inline size_t get_q_step(size_t arch, bool is_single_token = false) { + if (arch == 1) { + return is_single_token ? 1 : 8; // For Xe1 + } else if (arch == 2) { + // For Xe2, q_step = CM_GRF_WIDTH / 32 + return is_single_token ? 
1 : 16; // For Xe2 + } + OPENVINO_ASSERT(false, "Unsupported architecture for Q step"); + return 0; // Fallback case, should not be reached +} + +inline size_t get_kv_len(const RuntimeParams& params, const PagedAttentionStage& stage) { + if (stage == PagedAttentionStage::PREFILL) { + auto key_shape = params.input_layouts[PagedAttentionInputIdx::KEY].get_shape(); + const size_t kv_len = key_shape[key_shape.size() - 2]; + return kv_len; + } else { + // key_cache shape = [block_num, head_num, block_size(128), head_size] + auto key_cache_shape = params.input_layouts[PagedAttentionInputIdx::KEY_CACHE].get_shape(); + const size_t kv_len = key_cache_shape[0] * key_cache_shape[2]; + return kv_len; + } + OPENVINO_ASSERT(false, "Unsupported PagedAttentionStage for get_kv_len"); + return 0; // Fallback case, should not be reached +} + +inline size_t get_input_kv_len(const RuntimeParams& params) { + auto key_shape = params.input_layouts[PagedAttentionInputIdx::KEY].get_shape(); + const size_t kv_len = key_shape[key_shape.size() - 2]; + return kv_len; +} + +inline bool get_kv_compressed(const RuntimeParams& params) { + auto key_cache_layout = params.input_layouts[PagedAttentionInputIdx::KEY_CACHE]; + return data_type_traits::is_i8_u8(key_cache_layout.data_type); +} + +size_t get_partition_size(const bool has_xattention) { + if (!has_xattention && PA_KV_CACHE_BLOCK_SIZE < 128) { + return 128; + } else { + return PA_KV_CACHE_BLOCK_SIZE_XATTN; + } +} + +// max_context_len = max(past_lens + prompt_lens) +size_t get_max_context_len(const kernel_impl_params& params) { + const auto& input_mem = params.memory_deps; + const auto max_context_len = input_mem.at(PagedAttentionInputIdx::MAX_CONTEXT_LEN); + mem_lock max_context_len_mem_lock(max_context_len, *params.strm); + const auto paged_attention_max_len = max_context_len_mem_lock[0]; + return paged_attention_max_len; +} + +size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx) { + const auto& input_mem = params.memory_deps; + const auto past_len = input_mem.at(PagedAttentionInputIdx::PAST_LENS); + mem_lock past_len_mem_lock(past_len, *params.strm); + const auto paged_attention_past_len = past_len_mem_lock[seq_idx]; + return paged_attention_past_len; +} + +// TODO: change xattn_thresh from scalar to memory... once we remove the converter node +// between parameter node "xattention_threshold.xxx" and paged_attention node. +float get_xattn_thresh(const kernel_impl_params& params, const size_t seq_idx) { + const auto& input_mem = params.memory_deps; + const auto threshold_mem = input_mem.at(PagedAttentionInputIdx::XATTENTION_THRESHOLD); + mem_lock lock(threshold_mem, *params.strm); // converted + const auto thresh = static_cast(lock[seq_idx]); + return thresh; +} + +// Bypass the xattn stages in either of the following conditions: +// the threshold is >= 1.0, or q_len is too small +// to compute the xattn block_mask. 
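+// Illustrative (hypothetical numbers): with STRIDE = 16, a 10-token prompt cannot fill a single stride, so the estimate is skipped; likewise a threshold >= 1.0 would keep every block anyway, making the estimate pointless.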
+bool bypass_xattn(const kernel_impl_params& params) { + bool bypass = false; + bool allow_bypass = params.get_program().get_config().get_allow_bypass_xattn(); + if (allow_bypass) { + auto xattn_thresh = get_xattn_thresh(params); + bypass = xattn_thresh >= 1.0; + } + + auto q_len = params.output_layouts[0].get_shape()[0]; + bypass |= q_len < static_cast(STRIDE); //# the estimate silently drops tails shorter than `stride` + return bypass; +} + +PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) { + const auto& query_shape = impl_param.get_input_layout(PagedAttentionInputIdx::QUERY).get_partial_shape(); + const auto& past_lens_shape = impl_param.get_input_layout(PagedAttentionInputIdx::PAST_LENS).get_partial_shape(); + + if (query_shape.is_static() && past_lens_shape.is_static()) { + if (query_shape[0].get_length() == past_lens_shape[0].get_length()) { + return PagedAttentionStage::GENERATE; + } + + const auto& memory_deps = impl_param.memory_deps; + const auto past_lens_mem = memory_deps.at(PagedAttentionInputIdx::PAST_LENS); + mem_lock past_lens_mem_lock(past_lens_mem, *impl_param.strm); + + const auto past_lens_size = past_lens_mem_lock.size(); + for (size_t i = 0; i < past_lens_size; i++) { + if (past_lens_mem_lock[i] != 0) { + return PagedAttentionStage::MIXED; + } + } + return PagedAttentionStage::PREFILL; + } + return PagedAttentionStage::UNKNOWN; +} + +//----------------------------------------------------------------------------------------------------------------- +// Base generator +//----------------------------------------------------------------------------------------------------------------- +JitConstants PagedAttentionGeneratorBase::get_jit_constants(const kernel_impl_params& params) const { + auto jit = KernelGenerator::get_jit_constants(params); + jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); + + auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 
1 : 2; + jit.make("XE_ARCH", xe_arch); + + auto split_size = get_kv_split_size(xe_arch); + jit.make("KV_STEP", split_size.first); + + jit.make("WG_SIZE", WG_SIZE); + jit.make("CAUSAL_MASK", 1); + return jit; +} + +//----------------------------------------------------------------------------------------------------------------- +// KV cache update generator +//----------------------------------------------------------------------------------------------------------------- +JitConstants PagedAttentionGeneratorKVCacheUpdate::get_jit_constants(const kernel_impl_params& params) const { + auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); + + const auto desc = params.typed_desc(); + jit.make("KV_HEADS_NUM", desc->kv_heads_num); + jit.make("K_HEAD_SIZE", desc->k_head_size); + jit.make("V_HEAD_SIZE", desc->v_head_size); + if (desc->has_xattention) { + jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE_XATTN); + } else { + jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + } + + if (get_kv_compressed(params)) { + jit.make("KV_CACHE_COMPRESSION_PER_TOKEN", 1); + jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size + 4); + jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size + 4); + } else { + jit.make("KV_CACHE_COMPRESSION_PER_TOKEN", 0); + jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size); + jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size); + } + + return jit; +} + +Arguments PagedAttentionGeneratorKVCacheUpdate::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY}); // keys + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE}); // values + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::PAST_LENS}); // past lens + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block indices + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // key cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE_CACHE}); // value cache + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // key_pitch + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // key_offset + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // value_pitch + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // value_offset + args.push_back({ArgumentDescriptor::Types::SCALAR, 4}); // batch_size_in_sequences + return args; +} + +DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func() const { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + OPENVINO_ASSERT(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + const auto desc = params.typed_desc(); + + const size_t kv_len = get_input_kv_len(params); + const size_t kv_heads_num = desc->kv_heads_num; + const size_t wg_count = (kv_len + WG_SIZE - 1) / WG_SIZE; + + wgs.global = {1, kv_heads_num, wg_count * WG_SIZE}; + wgs.local = {1, 1, WG_SIZE}; + + auto& scalars = kd.params.scalars; + size_t key_pitch = desc->k_head_size * kv_heads_num; + size_t value_pitch = desc->v_head_size * kv_heads_num; + auto key_layout = 
+ + auto& scalars = kd.params.scalars; + size_t key_pitch = desc->k_head_size * kv_heads_num; + size_t value_pitch = desc->v_head_size * kv_heads_num; + auto key_layout = params.input_layouts[PagedAttentionInputIdx::KEY]; + auto value_layout = params.input_layouts[PagedAttentionInputIdx::VALUE]; + + auto get_simple_pitch = [](const layout& layout) { + size_t pitch = 1; + auto dims_padding = layout.get_padded_dims(); + for (size_t i = dims_padding.size() - 1; i > 0; --i) { + pitch = dims_padding[i]; + if (pitch > 1) { + break; + } + } + return pitch; + }; + key_pitch = get_simple_pitch(key_layout); + value_pitch = get_simple_pitch(value_layout); + + auto get_simple_offset = [](const layout& layout) { + size_t offset = 0; + const auto& data_padding = layout.data_padding; + const auto& lower_pads = data_padding._lower_size; + for (auto& it : lower_pads) { + if (it > 0) { + offset = it; + break; + } + } + return offset; + }; + size_t key_offset = get_simple_offset(key_layout); + size_t value_offset = get_simple_offset(value_layout); + + if (DEBUG_ENABLED) { // Debug + std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: " << "kv_len: " << kv_len << ", key_pitch: " << key_pitch + << ", key_offset: " << key_offset << ", value_pitch: " << value_pitch << ", value_offset: " << value_offset << ", gws: [" << wgs.global[0] + << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] + << "]" << std::endl; + } + // TODO: support multiple sequences + size_t batch_size_in_sequences = 1; + std::vector<size_t> scaler_value = {key_pitch, key_offset, value_pitch, value_offset, batch_size_in_sequences}; + scalars.resize(scaler_value.size()); + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast<int32_t>(scaler_value[i]); + } + }}; +} + +//----------------------------------------------------------------------------------------------------------------- +// multi token generator +//----------------------------------------------------------------------------------------------------------------- +Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_impl_params& params) const { + const auto desc = params.typed_desc<paged_attention>(); + + Arguments args; + // Doesn't support Query with dynamic_padding + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::QUERY}); // query + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // key_cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE_CACHE}); // value_cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::PAST_LENS}); // past_lens + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block_indices + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block_indices_begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence_begins + + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); + + if (desc->has_xattention) { + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK}); // sparse_block_mask + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK_MERGED}); // sparse_block_mask_wg + } + + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len + if (desc->has_xattention) { + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // k_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // validate + } + return args; +}
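+// With XAttention the prefill kernel takes two extra internal buffers (the per-block sparse mask +// and its merged variant) plus three extra scalars; the trailing `validate` scalar tells the kernel +// whether the estimated sparse mask should actually be applied (it is false when bypass_xattn() +// decides to run densely).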
+ +JitConstants PagedAttentionGeneratorMultiToken::get_jit_constants(const kernel_impl_params& params) const { + auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); + const auto desc = params.typed_desc<paged_attention>(); + const float scale_factor = 1.0f / std::sqrt(static_cast<float>(desc->k_head_size)); + auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; + + const size_t xattn_block_size = get_xattn_block_size(params); + + jit.make("CMFLA_NUM_HEADS", desc->heads_num); + jit.make("CMFLA_NUM_KV_HEADS", desc->kv_heads_num); + jit.make("CMFLA_HEAD_SIZE", desc->k_head_size); + jit.add(make_jit_constant("CMFLA_SCALE_FACTOR", scale_factor)); + jit.make("CMFLA_IS_CAUSAL", 1); + if (desc->has_xattention) { + jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE_XATTN); + jit.make("SPARSE_BLOCK_SIZE", xattn_block_size); + } else { + jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE); + jit.make("SPARSE_BLOCK_SIZE", 1); + } + jit.make("Q_STEP", get_q_step(xe_arch, true)); + + if (get_kv_compressed(params)) { + jit.make("CMPA_KVCACHE_U8", 1); + } else { + jit.make("CMPA_KVCACHE_U8", 0); + } + return jit; +}
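+// Illustrative values with the constants from this patch: with XAttention, CMPA_BLOCK_SZ = 256 +// (the xattn KV-cache block) and SPARSE_BLOCK_SIZE = 128 (the estimated-mask granularity); +// without XAttention the sparse mask degenerates to per-token granularity (SPARSE_BLOCK_SIZE = 1).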
+ +DispatchDataFunc PagedAttentionGeneratorMultiToken::get_dispatch_data_func() const { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + auto& wgs = kd.params.workGroups; + auto& scalars = kd.params.scalars; + auto desc = params.typed_desc<paged_attention>(); + auto rtp = static_cast<PagedAttentionRuntimeParams*>(rt_params); + // OPENVINO_ASSERT(rt_params != nullptr); + const size_t heads_num = desc->heads_num; + auto query_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; + + auto out_shape = params.output_layouts[0].get_shape(); + const size_t batch = out_shape.size() < 4 ? 1 : out_shape[0]; + const size_t q_len = out_shape[0]; + + auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; + const size_t q_step = get_q_step(xe_arch, false); + const size_t wg_seq_len = WG_SIZE * q_step; + const size_t wg_count = align_to(q_len, wg_seq_len) / wg_seq_len; + + wgs.global = {batch, heads_num, wg_count * WG_SIZE}; + wgs.local = {1, 1, WG_SIZE}; + + if (DEBUG_ENABLED) { // Debug + std::cout << "PagedAttentionGeneratorMultiToken::get_dispatch_data_func: \n" + << "\tbatch: " << batch << ", heads_num: " << heads_num << ", q_len: " << q_len << ", q_step: " << q_step + << ", wg_seq_len: " << wg_seq_len << ", wg_count: " << wg_count << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " + << wgs.global[2] << "]" << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + } + auto num_scalers = desc->has_xattention ? 4 : 1; + scalars.resize(num_scalers); + scalars[0].t = ScalarDescriptor::Types::INT32; + scalars[0].v.s32 = static_cast<int32_t>(q_len); + if (num_scalers > 1) { + scalars[1].t = ScalarDescriptor::Types::INT32; + scalars[1].v.s32 = static_cast<int32_t>(rtp->q_block_pad); + + scalars[2].t = ScalarDescriptor::Types::INT32; + scalars[2].v.s32 = static_cast<int32_t>(rtp->k_block_pad); + + scalars[3].t = ScalarDescriptor::Types::UINT8; + const bool validate = !bypass_xattn(params); + scalars[3].v.u8 = static_cast<uint8_t>(validate); // validate depending on xattn_threshold + } + }}; +} + +//----------------------------------------------------------------------------------------------------------------- +// single token generator +//----------------------------------------------------------------------------------------------------------------- +JitConstants PagedAttentionGeneratorSingleToken::get_jit_constants(const kernel_impl_params& params) const { + auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); + auto desc = params.typed_desc<paged_attention>(); + const float scale_factor = 1.0f / std::sqrt(static_cast<float>(desc->k_head_size)); + const size_t kv_partition_size = get_partition_size(desc->has_xattention); + auto xe_arch = params.get_device_info().arch < gpu_arch::xe2 ? 1 : 2; + + jit.make("KV_PARTITION_SIZE", kv_partition_size); + if (desc->has_xattention) { + jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE_XATTN); + } else { + jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE); + } + jit.add(make_jit_constant("SCALE_FACTOR", scale_factor)); + jit.make("HEAD_SIZE", desc->k_head_size); + jit.make("HEADS_NUM", desc->heads_num); + jit.make("KV_HEADS_NUM", desc->kv_heads_num); + jit.make("Q_STEP", get_q_step(xe_arch, true)); + + if (get_kv_compressed(params)) { + jit.make("KV_CACHE_COMPRESSION", 1); + } else { + jit.make("KV_CACHE_COMPRESSION", 0); + } + + return jit; +} + +Arguments PagedAttentionGeneratorSingleToken::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + const auto desc = params.typed_desc<paged_attention>(); + // const auto has_scale_input = !desc->scale_val.has_value(); + const auto has_scores_output = params.output_layouts.size() > 1; + OPENVINO_ASSERT(!has_scores_output, "[GPU][CM] PagedAttentionGeneratorSingleToken with scores output is not supported yet"); + + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::QUERY}); // queries + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // keys cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::VALUE_CACHE}); // values cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::PAST_LENS}); // past lens + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block indices + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence begins + + // outputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_PARTITIONOUT}); // partition output + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_EXPSUMS}); // lse output + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len==1 + + return args; +} + +DispatchDataFunc
PagedAttentionGeneratorSingleToken::get_dispatch_data_func() const { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + OPENVINO_ASSERT(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + const auto desc = params.typed_desc<paged_attention>(); + auto rtp = static_cast<PagedAttentionRuntimeParams*>(rt_params); + OPENVINO_ASSERT(rt_params != nullptr); + + const size_t batch = params.input_layouts[0].get_partial_shape()[0].get_length(); + const size_t heads_num = desc->heads_num; + const size_t kv_heads_num = desc->kv_heads_num; + const size_t partition_num = rtp->num_of_partitions; + + wgs.global = {batch, kv_heads_num, partition_num}; + wgs.local = {1, 1, 1}; + + // generate stage: q_len=1 + auto& scalars = kd.params.scalars; + std::vector<size_t> scaler_value = {1}; + scalars.resize(scaler_value.size()); + + if (DEBUG_ENABLED) { // Debug + size_t kv_len = get_kv_len(params, PagedAttentionStage::GENERATE); + size_t max_context_len = get_max_context_len(params); + size_t past_len = get_past_len(params, 0); + std::cout << "PagedAttentionGeneratorSingleToken::get_dispatch_data_func: " << "batch: " << batch << ", heads_num: " << heads_num + << ", partition_num: " << partition_num << ", kv_len: " << kv_len << ", max_context_len = " << max_context_len + << ", past_len = " << past_len << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" << ", lws: [" + << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + } + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast<int32_t>(scaler_value[i]); + } + }}; +} + +//----------------------------------------------------------------------------------------------------------------- +// single token finalization generator +//----------------------------------------------------------------------------------------------------------------- +JitConstants PagedAttentionGeneratorSingleTokenFinalization::get_jit_constants(const kernel_impl_params& params) const { + auto jit = PagedAttentionGeneratorBase::get_jit_constants(params); + const auto desc = params.typed_desc<paged_attention>(); + + jit.make("REDUCE_SPLIT_SIZE", reduce_split_step); + jit.make("HEAD_SIZE", desc->k_head_size); + jit.make("HEADS_NUM", desc->heads_num); + return jit; +} + +Arguments PagedAttentionGeneratorSingleTokenFinalization::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + const auto has_scores_output = params.output_layouts.size() > 1; + if (has_scores_output) + OPENVINO_THROW("[GPU][CM] PagedAttentionGeneratorSingleTokenFinalization with scores output is not supported yet"); + + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_PARTITIONOUT}); // partition data + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // output + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::DECODE_EXPSUMS}); // lse + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // kv_partition_num + + return args; +} + +DispatchDataFunc PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_data_func() const { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + OPENVINO_ASSERT(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + + const auto desc = params.typed_desc<paged_attention>(); + auto rtp = static_cast<PagedAttentionRuntimeParams*>(rt_params); + + OPENVINO_ASSERT(rt_params != nullptr); + + const size_t batch =
params.input_layouts[0].get_partial_shape()[0].get_length(); + const size_t heads_num = desc->heads_num; + const size_t head_size = desc->k_head_size; + wgs.global = {batch, heads_num, head_size / reduce_split_step}; + wgs.local = {1, 1, 1}; + + auto& scalars = kd.params.scalars; + const size_t partition_num = rtp->num_of_partitions; + std::vector<size_t> scaler_value = {partition_num}; + scalars.resize(scaler_value.size()); + + if (DEBUG_ENABLED) { // Debug + std::cout << "PagedAttentionGeneratorSingleTokenFinalization::get_dispatch_data_func: " << "batch: " << batch << ", heads_num: " << heads_num + << ", partition_num: " << partition_num << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]" + << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl; + } + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast<int32_t>(scaler_value[i]); + } + }}; +} + +//----------------------------------------------------------------------------------------------------------------- +// Helpers of XAttention +//----------------------------------------------------------------------------------------------------------------- + +//----------------------------------------------------------------------------------------------------------------- +// Base generator of XAttention +//----------------------------------------------------------------------------------------------------------------- +JitConstants XAttentionEstimateGeneratorBase::get_jit_constants(const kernel_impl_params& params) const { + auto jit = KernelGenerator::get_jit_constants(params); + jit.add(make_jit_constant("KERNEL_NAME", get_entry_point(params))); + + auto desc = params.typed_desc<paged_attention>(); + + const float scale_factor = 1.0f / std::sqrt(static_cast<float>(desc->k_head_size)) / STRIDE; + int scale_factor_i; + std::memcpy(static_cast<void*>(&scale_factor_i), &scale_factor, sizeof(scale_factor)); + + const uint32_t wg_k = BLOCK_WG_M; + const uint32_t wg_q = BLOCK_WG_N; + const size_t block_size = get_xattn_block_size(params); + OPENVINO_ASSERT(wg_k % block_size == 0, "wg_k must be a multiple of block_size so that block_size leaves no tail"); + OPENVINO_ASSERT(wg_q % block_size == 0, "wg_q must be a multiple of block_size so that block_size leaves no tail"); + + jit.make("STRIDE", STRIDE); + jit.make("HQ", desc->heads_num); + jit.make("HK", desc->kv_heads_num); + jit.make("HEAD_SIZE", desc->k_head_size); + jit.make("SG_M", SG_M); + jit.make("SG_N", SG_N); + jit.make("BLOCK_SG_M", BLOCK_SG_M); + jit.make("BLOCK_SG_N", BLOCK_SG_N); + jit.make("BLOCK_SIZE", block_size); + jit.make("KV_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE_XATTN); + jit.add(make_jit_constant("INV_S", scale_factor_i)); + jit.make("BLOCK_SHARE_MAX", BLOCK_WG_N); + //# The loop order walks HQ first with step WALK_HQ: 1 means HQ is not walked first, 2 means two query heads are walked together. Valid values: 1, 2, 4... + jit.make("WALK_HQ", desc->heads_num != desc->kv_heads_num ? 2 : 1);
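+ // e.g. for a GQA model with heads_num = 32 and kv_heads_num = 8 this sets WALK_HQ = 2, so the + // estimate kernels walk query heads in pairs that share one KV head; for MHA (heads_num == + // kv_heads_num) WALK_HQ stays 1 and heads are walked one by one.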
+ jit.make("IS_CAUSAL", 1); + if (get_kv_compressed(params)) { + jit.make("USE_INT8", 1); + jit.make("HEAD_SIZE_KEY", desc->k_head_size + 2 * 2); + } else { + jit.make("USE_INT8", 0); + jit.make("HEAD_SIZE_KEY", desc->k_head_size); + } + jit.make("SOFTMAX_TYPE", "float"); + + return jit; +} + +//----------------------------------------------------------------------------------------------------------------- +// XAttention Estimate gemm_qk generator +//----------------------------------------------------------------------------------------------------------------- +Arguments XAttentionEstimateGEMMQK::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::KEY_CACHE}); // keys cache + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::QUERY}); // queries + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES}); // block indices + args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::BLOCK_INDICES_BEGINS}); // block indices begins + + // outputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_MAX}); // kq_max_wg + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_EXPSUMS}); // kq_exp_partial_sum + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // M + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // N + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // K + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // query_pitch + args.push_back({ArgumentDescriptor::Types::SCALAR, 4}); // slice_no + args.push_back({ArgumentDescriptor::Types::SCALAR, 5}); // slice + args.push_back({ArgumentDescriptor::Types::SCALAR, 6}); // q_start_strided + + return args; +}
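+// Scalar semantics (assumed from the runtime parameters): M and N are the query and key lengths in +// STRIDE granularity and K is the head size, e.g. q_len = 8192 with STRIDE = 16 gives M = 512; +// q_start_strided = N - M anchors the causal diagonal when the query forms the tail of the KV sequence.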
+ +DispatchDataFunc XAttentionEstimateGEMMQK::get_dispatch_data_func() const { + return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + OPENVINO_ASSERT(!params.is_dynamic()); + OPENVINO_ASSERT(rt_params != nullptr); + auto rtp = static_cast<PagedAttentionRuntimeParams*>(rt_params); + const auto desc = params.typed_desc<paged_attention>(); + + const auto M = rtp->M; + const auto N = rtp->N; + const auto K = rtp->K; + + auto get_simple_pitch = [](const layout& layout) { + size_t pitch = 1; + auto dims_padding = layout.get_padded_dims(); + for (size_t i = dims_padding.size() - 1; i > 0; --i) { + pitch = dims_padding[i]; + if (pitch > 1) { + break; + } + } + return pitch; + }; + auto query_layout = params.input_layouts[PagedAttentionInputIdx::QUERY]; + const size_t query_pitch = get_simple_pitch(query_layout) * STRIDE; + const size_t slice_no = 0, slice = 0; + + //# The loop order walks HQ first with step WALK_HQ: 1 means HQ is not walked first, 2 means two query heads are walked together. Valid values: 1, 2, 4... + const size_t WALK_HQ = desc->heads_num != desc->kv_heads_num ? 2 : 1; + + auto& wgs = kd.params.workGroups; + wgs.global = {rtp->N_kq_groups * (rtp->q_stride_pad / BLOCK_WG_M) * SG_N * WALK_HQ, SG_M, desc->heads_num / WALK_HQ}; + wgs.local = {SG_N, SG_M, 1}; + + const size_t q_start_strided = N - M; + OPENVINO_ASSERT(N >= M, "length of key cache must be greater than or equal to the query length"); + + auto& scalars = kd.params.scalars; + std::vector<size_t> scaler_value = {M, N, K, query_pitch, slice_no, slice, q_start_strided}; + scalars.resize(scaler_value.size()); + + for (size_t i = 0; i < scaler_value.size(); ++i) { + if (i == 4 || i == 5) { + scalars[i].t = ScalarDescriptor::Types::INT32; + scalars[i].v.s32 = static_cast<int32_t>(scaler_value[i]); + } else { + scalars[i].t = ScalarDescriptor::Types::UINT32; + scalars[i].v.u32 = static_cast<uint32_t>(scaler_value[i]); + } + } + }}; +} + +//----------------------------------------------------------------------------------------------------------------- +// XAttention Estimate find_block generator +//----------------------------------------------------------------------------------------------------------------- +Arguments XAttentionEstimateFindBlock::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + + // inputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_MAX}); // kq_max_wg + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_GEMMQK_EXPSUMS}); // kq_exp_partial_sum + + // outputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK}); // sparse_block_mask + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_len + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_stride + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // q_stride_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 3}); // q_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 4}); // k_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 5}); // causal_start_index + args.push_back({ArgumentDescriptor::Types::SCALAR, 6}); // thresh + + return args; +} + +DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const { + return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + OPENVINO_ASSERT(!params.is_dynamic()); + OPENVINO_ASSERT(rt_params != nullptr); + auto rtp = static_cast<PagedAttentionRuntimeParams*>(rt_params); + const auto desc = params.typed_desc<paged_attention>(); + + auto& wgs = kd.params.workGroups; + + const size_t heads_num = desc->heads_num; + + auto out_shape = params.output_layouts[0].get_shape(); + const size_t q_len = out_shape[0]; + + const size_t block_size = get_xattn_block_size(params); + const size_t sum_per_n_token_in_block = static_cast<size_t>(block_size / STRIDE); + const uint32_t q_block = ceil_div(rtp->M, sum_per_n_token_in_block); + const uint32_t k_block = ceil_div(rtp->N, sum_per_n_token_in_block); + + const float xattn_thresh = get_xattn_thresh(params); + + wgs.global = {rtp->q_block_pad, heads_num, 1}; + wgs.local = {1, 1, 1}; + + auto& scalars = kd.params.scalars; + std::vector<size_t> scaler_value = {q_len, rtp->M, rtp->q_stride_pad, rtp->q_block_pad, rtp->k_block_pad, k_block - q_block}; + scalars.resize(scaler_value.size() + 1); + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::UINT32; + scalars[i].v.u32 = static_cast<uint32_t>(scaler_value[i]); + } + scalars[scaler_value.size()].t = ScalarDescriptor::Types::FLOAT32; // the last scalar is thresh with f32 dtype + scalars[scaler_value.size()].v.f32 = static_cast<float>(xattn_thresh); + }}; +}
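+// Threshold semantics (assuming the usual XAttention selection rule): find_block keeps, per query +// block, the smallest set of key blocks whose accumulated softmax mass reaches `thresh`, e.g. +// thresh = 0.9 covers ~90% of the attention weight; thresh >= 1.0 keeps every block, which is why +// bypass_xattn() skips the whole estimation pass in that case.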
+ +//----------------------------------------------------------------------------------------------------------------- +// XAttention Estimate post_proc generator +//----------------------------------------------------------------------------------------------------------------- +JitConstants XAttentionEstimatePostProc::get_jit_constants(const kernel_impl_params& params) const { + auto jit = XAttentionEstimateGeneratorBase::get_jit_constants(params); + + jit.make("MERGED_Q_NUM", MERGED_Q_NUM); + + return jit; +} + +Arguments XAttentionEstimatePostProc::get_arguments_desc(const kernel_impl_params& params) const { + Arguments args; + + // inputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK}); // sparse_block_mask + + // outputs + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, PagedAttentionInternBuffIdx::XATTN_BLOCKMASK_MERGED}); // sparse_block_mask_wg + + // scalar + args.push_back({ArgumentDescriptor::Types::SCALAR, 0}); // q_stride_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // q_block_pad + args.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // k_block_pad + + return args; +} + +DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const { + return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + OPENVINO_ASSERT(!params.is_dynamic()); + OPENVINO_ASSERT(rt_params != nullptr); + auto rtp = static_cast<PagedAttentionRuntimeParams*>(rt_params); + const auto desc = params.typed_desc<paged_attention>(); + + auto& wgs = kd.params.workGroups; + + wgs.global = {rtp->q_block_pad_merged, desc->heads_num, 1}; + wgs.local = {1, 1, 1}; + + auto& scalars = kd.params.scalars; + std::vector<size_t> scaler_value = {rtp->q_stride_pad, rtp->q_block_pad, rtp->k_block_pad}; + scalars.resize(scaler_value.size()); + + for (size_t i = 0; i < scaler_value.size(); ++i) { + scalars[i].t = ScalarDescriptor::Types::UINT32; + scalars[i].v.u32 = static_cast<uint32_t>(scaler_value[i]); + } + }}; +} + +} // namespace ov::intel_gpu::cm diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp new file mode 100644 index 00000000000000..c630269d6f8696 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp @@ -0,0 +1,162 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "../ocl_v2/utils/jitter.hpp" +#include "common_utils/jitter.hpp" +#include "intel_gpu/graph/kernel_impl_params.hpp" +#include "intel_gpu/primitives/paged_attention.hpp" +#include "intel_gpu/runtime/layout.hpp" +#include "openvino/core/type.hpp" +#include "program_node.h" +#include "registry/implementation_manager.hpp" +#include "utils/kernel_generator.hpp" + +using namespace cldnn; // TODO: Remove once namespaces are aligned + +namespace ov::intel_gpu::cm { + +constexpr auto get_pa_build_options() { + return " -cmc -Qxcm_register_file_size=256"; +} + +// BLOCK_SIZE can be 16/256 for legacy and xattn cases respectively +#define PA_KV_CACHE_BLOCK_SIZE 16 +#define PA_KV_CACHE_BLOCK_SIZE_XATTN 256 + +constexpr uint32_t BLOCK_SG_M = 64; +constexpr uint32_t BLOCK_SG_N = 32; +constexpr uint32_t SG_M = 4; +constexpr uint32_t SG_N = 8; +constexpr uint32_t BLOCK_WG_M = BLOCK_SG_M * SG_M; +constexpr uint32_t BLOCK_WG_N = BLOCK_SG_N * SG_N; +constexpr int
STRIDE = 16; +constexpr uint32_t XATTN_BLOCK_SIZE = 128; +constexpr uint32_t MERGED_Q_NUM = PA_KV_CACHE_BLOCK_SIZE_XATTN / XATTN_BLOCK_SIZE; // for xattn post_proc + +enum class PagedAttentionStage : uint8_t { GENERATE = 0, PREFILL = 1, MIXED = 2, UNKNOWN = 3 }; +struct PagedAttentionRuntimeParams : public ImplRuntimeParams { + PagedAttentionStage stage; + size_t max_context_len; + // below are rt params for decoding + size_t num_of_partitions; + // below are rt params for xattn + size_t q_block_pad; + size_t k_block_pad; + size_t q_stride_pad; + size_t q_block_pad_merged; + size_t N_kq_groups; + size_t M; + size_t N; + size_t K; +}; + +enum PagedAttentionInternBuffIdx { + // for decoding kernels + DECODE_PARTITIONOUT = 0, // 0: intermediate partition output + DECODE_EXPSUMS = 1, // 1: softmax exp_sums + // for xattn kernels + XATTN_GEMMQK_MAX = 2, // 2: kq_max_wg + XATTN_GEMMQK_EXPSUMS = 3, // 3: kq_exp_partial_sum + XATTN_BLOCKMASK = 4, // 4: sparse_block_mask + XATTN_BLOCKMASK_MERGED = 5, // 5: sparse_block_mask_wg +}; + +//----------------------------------------------------------------------------------------------------------------- +// Helpers of XAttention +//----------------------------------------------------------------------------------------------------------------- +int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, int64_t target_seq_len_block_size); +PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param); +size_t get_max_context_len(const kernel_impl_params& params); +size_t get_past_len(const kernel_impl_params& params, const size_t seq_idx); +size_t get_partition_size(const bool has_xattention); + +float get_xattn_thresh(const kernel_impl_params& impl_param, const size_t seq_idx = 0); +bool bypass_xattn(const kernel_impl_params& impl_param); +inline size_t get_xattn_block_size(const kernel_impl_params& impl_param) { + return XATTN_BLOCK_SIZE; +} + +class PagedAttentionGeneratorBase : public KernelGenerator { +public: + explicit PagedAttentionGeneratorBase(std::string_view kernel_name, std::string_view stage_suffix = "_cm") : KernelGenerator(kernel_name, stage_suffix) {} + [[nodiscard]] std::string get_build_options(const RuntimeParams& params) const override { + return KernelGenerator::get_build_options(params) + get_pa_build_options(); + } + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; +}; + +class PagedAttentionGeneratorKVCacheUpdate : public PagedAttentionGeneratorBase { +public: + PagedAttentionGeneratorKVCacheUpdate() : PagedAttentionGeneratorBase("pa_kv_cache_update_ref") {} + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +class PagedAttentionGeneratorMultiToken : public PagedAttentionGeneratorBase { +public: + PagedAttentionGeneratorMultiToken() : PagedAttentionGeneratorBase("pa_multi_token") {} + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +class PagedAttentionGeneratorSingleToken : public PagedAttentionGeneratorBase { +public: + PagedAttentionGeneratorSingleToken() : 
PagedAttentionGeneratorBase("pa_single_token") {} + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +class PagedAttentionGeneratorSingleTokenFinalization : public PagedAttentionGeneratorBase { +public: + PagedAttentionGeneratorSingleTokenFinalization() : PagedAttentionGeneratorBase("pa_single_token_finalization") {} + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +//----------------------------------------------------------------------------------------------------------------- +// XAttention Estimate generators +//----------------------------------------------------------------------------------------------------------------- +class XAttentionEstimateGeneratorBase : public KernelGenerator { +public: + explicit XAttentionEstimateGeneratorBase(std::string_view kernel_name, std::string_view stage_suffix = "_cm") + : KernelGenerator(kernel_name, stage_suffix) {} + [[nodiscard]] std::string get_build_options(const RuntimeParams& params) const override { + return KernelGenerator::get_build_options(params) + get_pa_build_options(); + } + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; +}; +class XAttentionEstimateGEMMQK : public XAttentionEstimateGeneratorBase { +public: + XAttentionEstimateGEMMQK() : XAttentionEstimateGeneratorBase("xattn_gemm_qk") {} + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +class XAttentionEstimateFindBlock : public XAttentionEstimateGeneratorBase { +public: + XAttentionEstimateFindBlock() : XAttentionEstimateGeneratorBase("xattn_find_block") {} + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +class XAttentionEstimatePostProc : public XAttentionEstimateGeneratorBase { +public: + XAttentionEstimatePostProc() : XAttentionEstimateGeneratorBase("xattn_post_proc") {} + [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override; + [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override; + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override; +}; + +} // namespace ov::intel_gpu::cm \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm new file mode 100644 index 00000000000000..d3a61789b312ea --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_find_block.cm @@ -0,0 +1,70 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +namespace KERNEL_NAME { +#include "find_block.hpp" + +#ifndef ATTR +#define ATTR [[type("svmptr_t")]] +#define ATTR_BUF [[type("buffer_t")]] +#endif + +// _GENX_MAIN_ void find_block( +extern "C" _GENX_MAIN_ void KERNEL_NAME( + svmptr_t kq_max_wg ATTR, + svmptr_t kq_exp_partial_sum ATTR, + svmptr_t block_mask ATTR, + uint q_len, + uint q_stride, + uint q_stride_pad, + uint q_block_pad, + uint k_block_pad, + uint causal_start_index, + float thresh +#if DEBUG_ACC == 1 + , svmptr_t kq_sum ATTR +#endif +) { + // kq_max_wg: [b, hq, n_groups, q_stride_pad] + // kq_exp_partial_sum: [b, hq, q_stride_pad, k_block_pad] + // kq_sum: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] + // block_mask: [b, hq, q_stride_pad/TOKEN_IN_BLOCK, k_block_pad] + // [1, 32, 256], [1, 32, 64, 256], [1, 32, 256, 64 * 16], A_sum:[1, 32, 32, 64 * 16] + // global: [q_block_pad, hq, b] + const int TOKEN_IN_BLOCK = BLOCK_SIZE / STRIDE; + const int TOKEN_SHARE_MAX = BLOCK_SHARE_MAX / TOKEN_IN_BLOCK; + uint m = cm_group_id(0); + uint hq = cm_group_id(1); + uint b = cm_group_id(2); + kq_max_wg += (b * HQ + hq) * (k_block_pad / TOKEN_SHARE_MAX) * q_stride_pad * (uint)sizeof(SOFTMAX_TYPE); + kq_exp_partial_sum += (b * HQ + hq) * q_stride_pad * k_block_pad * (uint)sizeof(SOFTMAX_TYPE); +#if DEBUG_ACC == 1 + kq_sum += (b * HQ + hq) * (q_stride_pad / TOKEN_IN_BLOCK) * k_block_pad * (uint)sizeof(half); +#endif + block_mask += (b * HQ + hq) * q_block_pad * k_block_pad; + + const uint slm_size = 32 * 16 * sizeof(ushort); + cm_slm_init(slm_size); + auto slm = cm_slm_alloc(slm_size); + + find(slm, m, kq_max_wg, kq_exp_partial_sum, block_mask, q_len, q_stride, q_stride_pad, k_block_pad, thresh, causal_start_index +#if DEBUG_ACC == 1 + , kq_sum +#endif + ); +} + +} // NAMESPACE \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm new file mode 100644 index 00000000000000..24fb0ab4864b69 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_gemm_qk.cm @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +namespace KERNEL_NAME { +#include "estimate.hpp" + +#define ABS(x) (x) < 0 ? 
-(x) : (x) + +CM_INLINE void get_mn(uint& id_wg_m, uint& id_wg_n, uint M, uint N, int slice_no, int slice, const int BLOCK_WG_M, const int BLOCK_WG_N) { + uint id_wg_mn = cm_group_id(0) / WALK_HQ; + if (slice_no == 0) { + if (slice == 0) { + // loop M first, N is shared, total = N/256*M+N + uint WG_MN = (M + BLOCK_WG_M - 1) / BLOCK_WG_M; + id_wg_m = id_wg_mn % WG_MN; + id_wg_n = id_wg_mn / WG_MN; + } else { + // loop N first, M is shared, total = M/128*N+M + uint WG_MN = (N + BLOCK_WG_N - 1) / BLOCK_WG_N; + id_wg_n = id_wg_mn % WG_MN; + id_wg_m = id_wg_mn / WG_MN; + } + } else { + uint wg_x = slice > 0 ? N / BLOCK_WG_N : M / BLOCK_WG_M; + uint slice_no_abs = ABS(slice_no); + uint slice_abs = ABS(slice); + int id_wg_mn_in_reminder = (int)id_wg_mn - (int)(slice_no_abs * slice_abs * wg_x); + uint slice_idx; + // in [slice_no x slice] + if (id_wg_mn_in_reminder < 0) { + slice_idx = id_wg_mn / (slice_abs * wg_x); + uint rem_in_slice = id_wg_mn % (slice_abs * wg_x); + uint x = rem_in_slice % slice_abs; + uint y = rem_in_slice / slice_abs; + id_wg_m = slice > 0 ? x + slice_idx * slice_abs : y; + id_wg_n = slice < 0 ? x + slice_idx * slice_abs : y; + } else { + uint slice_rem = slice_abs + (slice_no > 0 ? 1 : -1); + slice_idx = id_wg_mn_in_reminder / (slice_rem * wg_x); + uint rem_in_slice = id_wg_mn_in_reminder % (slice_rem * wg_x); + uint x = rem_in_slice % slice_rem; + uint y = rem_in_slice / slice_rem; + id_wg_m = slice > 0 ? x + slice_idx * slice_rem + slice_no_abs * slice_abs : y; + id_wg_n = slice < 0 ? x + slice_idx * slice_rem + slice_no_abs * slice_abs : y; + } + } +} + +extern "C" _GENX_MAIN_ void KERNEL_NAME( + svmptr_t key_cache ATTR, + svmptr_t query ATTR, + svmptr_t block_indices ATTR, + svmptr_t block_indices_begins ATTR, + svmptr_t kq_max_wg ATTR, + svmptr_t kq_exp_partial_sum ATTR, + uint M, uint N, uint K, uint query_stride, int slice_no, int slice, uint q_start_strided) { + const uint BLOCK_WG_M = BLOCK_SG_M * SG_M; + const uint BLOCK_WG_N = BLOCK_SG_N * SG_N; + const uint size_slm_b = 0; + uint hq = cm_group_id(2) * WALK_HQ; + hq += cm_group_id(0) & (WALK_HQ - 1); + if (hq >= HQ) return; + uint hk = hq / (HQ / HK); + const uint slm_size = SG_N * BLOCK_WG_M * sizeof(SOFTMAX_TYPE); + cm_slm_init(slm_size); + auto slm = cm_slm_alloc(slm_size); + + static_assert(HQ % HK == 0, "HQ must be multiple of HK"); + + uint id_wg_m, id_wg_n; + get_mn(id_wg_m, id_wg_n, M, N, slice_no, slice, BLOCK_WG_M, BLOCK_WG_N); + + // key cache: [block, HQ, KV_BLOCK_SIZE, HEAD_SIZE_KEY] +#if USE_INT8 + key_cache += hk * (KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(char)); +#else + key_cache += hk * (KV_BLOCK_SIZE * HEAD_SIZE_KEY * (uint)sizeof(half)); +#endif + // query: [l_q, HQ * HEAD_SIZE] + query += hq * HEAD_SIZE * (uint)sizeof(half); + + // kq_max: [hq, m_pad] + // kq_max_wg: [hq, n_groups, m_pad] + // kq_exp_partial_sum: [hq, m_pad, n_groups*BLOCK_WG_M/(BLOCK_SIZE/STRIDE)] + uint m_pad = (M + BLOCK_WG_M - 1) / BLOCK_WG_M * BLOCK_WG_M; + uint n_groups = (N + BLOCK_WG_N - 1) / BLOCK_WG_N; + kq_max_wg += hq * n_groups * m_pad * (uint)sizeof(SOFTMAX_TYPE); + + const uint sum_per_n_token_in_block = BLOCK_SIZE / STRIDE; + const uint n_after_sum_in_group = BLOCK_WG_N / sum_per_n_token_in_block; + const uint n_after_sum_pad = n_after_sum_in_group * n_groups; + kq_exp_partial_sum += hq * n_after_sum_pad * m_pad * (uint)sizeof(SOFTMAX_TYPE); + +#define CONCAT_IMPL(a, b) gemm_qk_ ##a ##x ##b ##_xe2 +#define CONCAT(x, y) CONCAT_IMPL(x, y) +#define FUNC CONCAT(BLOCK_SG_M, BLOCK_SG_N) + FUNC(id_wg_m, id_wg_n, 
hq, slm, key_cache, query, block_indices, block_indices_begins, kq_max_wg, kq_exp_partial_sum, M, N, K, query_stride, q_start_strided); +} + +} // NAMESPACE \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm new file mode 100644 index 00000000000000..0bbab306c2f5ef --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xattn_post_proc.cm @@ -0,0 +1,46 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +namespace KERNEL_NAME { +#include "estimate.hpp" + +extern "C" _GENX_MAIN_ void KERNEL_NAME(svmptr_t block_mask ATTR, svmptr_t merged_block_mask ATTR, uint q_stride_pad, uint q_block_pad, uint k_block_pad) { + // block_mask: [b, hq, q_block_pad, k_block_pad] + // merged_block_mask: [b, hq, q_block_pad/MERGED_Q_NUM, k_block_pad] + // global: [q_block_pad/MERGED_Q_NUM, hq, b] + const int TOKEN_IN_BLOCK = BLOCK_SIZE / STRIDE; + const int TOKEN_SHARE_MAX = BLOCK_SHARE_MAX / TOKEN_IN_BLOCK; + uint m_merged = cm_group_id(0); + uint hq = cm_group_id(1); + uint b = cm_group_id(2); + block_mask += (b * HQ + hq) * q_block_pad * k_block_pad; + merged_block_mask += (b * HQ + hq) * cm_group_count(0) * k_block_pad; + merged_block_mask += m_merged * k_block_pad; + block_mask += m_merged * MERGED_Q_NUM * k_block_pad; + vector one = 1; + for (int j = 0; j < k_block_pad; j += 32) { + vector new_mask = cm_ptr_load((int*)block_mask, j).format(); + for (int i = 1; i < MERGED_Q_NUM; i++) { + if (m_merged * MERGED_Q_NUM + i < q_stride_pad / TOKEN_IN_BLOCK) { + vector cur_mask = cm_ptr_load((int*)block_mask, j + i * k_block_pad).format(); + new_mask |= cur_mask; + } + } + cm_ptr_store((int*)merged_block_mask, j, new_mask.format()); + } +} + +} // NAMESPACE diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/primitive_ocl_base.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/primitive_ocl_base.hpp index 6507e97188ff97..5fd64ded54411f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/primitive_ocl_base.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/primitive_ocl_base.hpp @@ -251,13 +251,9 @@ struct PrimitiveImplOCL : public cldnn::primitive_impl { GPU_DEBUG_TRACE_DETAIL << "\t" << i << ": type = " << static_cast(params.arguments[i].t) << ", index = " << params.arguments[i].index << '\n'; } - GPU_DEBUG_TRACE_DETAIL << "Memory buffers:" - << "shape_info=" << args.shape_info << " " - << "inputs=" << args.inputs.size() << " " - << "outputs=" << args.outputs.size() << " " - << "intermediates=" << args.intermediates.size() << " " - << "weights=" << args.weights << " " - << "scalars=" << (args.scalars ?
args.scalars->size() : 0) << "\n"; + GPU_DEBUG_TRACE_DETAIL << "Memory buffers:" << "shape_info=" << args.shape_info << " " << "inputs=" << args.inputs.size() << " " + << "outputs=" << args.outputs.size() << " " << "intermediates=" << args.intermediates.size() << " " + << "weights=" << args.weights << " " << "scalars=" << (args.scalars ? args.scalars->size() : 0) << "\n"; stream.set_arguments(*stage.kernel, params, args); kd.need_args_update = false; } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp index 8fccefbc9e5eae..d3c3c92f6f5a77 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp @@ -8,6 +8,7 @@ #include #include +#include "paged_attention_inst.h" #include "program_node.h" #include "registry/implementation_manager.hpp" @@ -30,6 +31,12 @@ struct PagedAttentionOpt : public ImplementationManager { ov::element::i8, }; + auto desc = node.as<paged_attention>().get_primitive(); + if (desc->has_xattention) { + GPU_DEBUG_TRACE_DETAIL << "validate_impl() - false because XAttention is not supported by the OCL implementation." << std::endl; + return false; + } + const auto& q_layout = node.get_input_layout(0); const auto& k_layout = node.get_input_layout(1); const auto& v_layout = node.get_input_layout(2); diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h index cf4d1a0f1d8d00..dbe8a27f0a79c6 100644 --- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h @@ -22,6 +22,7 @@ struct typed_program_node : public typed_program_node_base get_lockable_input_ids() const override { std::set input_ports = { PagedAttentionInputIdx::PAST_LENS, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS, + PagedAttentionInputIdx::XATTENTION_THRESHOLD, PagedAttentionInputIdx::MAX_CONTEXT_LEN }; if (typed_desc()->has_score_aggregation) diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index e4d566b2a87664..1b0bc58d79471a 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -11,8 +11,6 @@ namespace cldnn { GPU_DEFINE_PRIMITIVE_TYPE_ID(paged_attention) -constexpr size_t paged_attention::block_size; - layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*node*/, kernel_impl_params const& impl_param) { auto out_layout = impl_param.get_input_layout(0); @@ -34,22 +32,35 @@ std::vector paged_attention_inst::calc_output_layouts(paged_attention_no data_layout.data_padding = padding(); - const auto& key_cache_idx = cldnn::paged_attention::PagedAttentionInputIdx::KEY_CACHE; + const auto key_cache_idx = cldnn::paged_attention::PagedAttentionInputIdx::KEY_CACHE; const auto& key_cache_ps = impl_param.get_input_layout(key_cache_idx).get_partial_shape(); const auto& key_cache_quant_mode = impl_param.get_program().get_config().get_key_cache_quant_mode(); bool key_cache_compressed = impl_param.get_input_layout(key_cache_idx).data_type == ov::element::i8 || impl_param.get_input_layout(key_cache_idx).data_type == ov::element::u8; - auto expected_block_size = paged_attention::block_size; + auto expected_block_size = desc->has_xattention ?
paged_attention::block_size_xattn : paged_attention::block_size; if (key_cache_compressed && key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) { expected_block_size += 4; } OPENVINO_ASSERT((key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL) == desc->is_key_by_channel, "[GPU] Paged Attention key cache quantization mode mismatch: prim.is_key_by_channel : ", desc->is_key_by_channel, " but exec_config : ", impl_param.get_program().get_config().get_key_cache_quant_mode()); + + const auto block_size_idx = desc->has_xattention ? 2 : 3; bool valid_block_size = key_cache_ps.is_dynamic() || - (key_cache_ps[key_cache_idx].get_length() == static_cast<int64_t>(expected_block_size)); + (key_cache_ps[block_size_idx].get_length() == static_cast<int64_t>(expected_block_size)); OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation for key cache quant mode " - , key_cache_quant_mode, ". Expected ", expected_block_size, ", but got ", key_cache_ps[key_cache_idx].get_length()); + , key_cache_quant_mode, ". Expected ", expected_block_size, ", but got ", key_cache_ps[block_size_idx].get_length()); + + // TODO: as a preview feature, only a single sequence is supported so far. This check will be removed + // once the full functionality is ready. + if (desc->has_xattention) { + const auto& subseq_begins_ps = impl_param.get_input_layout(PagedAttentionInputIdx::SUBSEQUENCE_BEGINS).get_partial_shape(); + bool valid_subseq_count = subseq_begins_ps.is_dynamic() || + (subseq_begins_ps[0].get_length() == static_cast<int64_t>(2)); + if (!valid_subseq_count) + OPENVINO_THROW("[GPU] Unexpected subsequence count for XAttention. Got ", subseq_begins_ps[0].get_length() - 1); + } + std::vector output_layouts{ data_layout }; + if (desc->has_scores_output()) { @@ -123,7 +134,9 @@ paged_attention_inst::typed_primitive_inst(network& network, const paged_attenti } OPENVINO_ASSERT(heads_num % kv_heads_num == 0); - OPENVINO_ASSERT(k_head_size % pa_block_size == 0); - OPENVINO_ASSERT(v_head_size % pa_block_size == 0); + if (!desc->has_xattention) { + OPENVINO_ASSERT(k_head_size % pa_block_size == 0); + OPENVINO_ASSERT(v_head_size % pa_block_size == 0); + } } } // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp b/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp index c1fc982350054e..6c68f8806a410e 100644 --- a/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp +++ b/src/plugins/intel_gpu/src/graph/registry/paged_attention_impls.cpp @@ -10,6 +10,10 @@ #include "impls/ocl_v2/sdpa/paged_attention_opt.hpp" #endif +#if OV_GPU_WITH_CM + #include "impls/cm/paged_attention.hpp" +#endif + namespace ov { namespace intel_gpu { @@ -18,6 +22,7 @@ using namespace cldnn; const std::vector>& Registry::get_implementations() { static const std::vector> impls = { OV_GPU_CREATE_INSTANCE_OCL(ocl::PagedAttentionOpt, shape_types::any) + OV_GPU_CREATE_INSTANCE_CM(cm::PagedAttentionImplementationManager, shape_types::any) }; return impls; diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index b83792f0f9fc52..1a169c51059e59 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -37,9 +37,19 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared auto key_cache_ps = op->get_input_partial_shape(3); auto value_cache_ps = op->get_input_partial_shape(4); - auto k_head_size = has_rt_params ? rt_info.at(k_head_size_id).as<int64_t>() : key_cache_ps[2].get_length(); + // We may fall back to dense attention mode if xattn is not supported by either the GPU architecture or the compiler. + // So we check the block_size from the value cache shape, instead of the op input type, to determine whether xattn is enabled. + if (value_cache_ps[2].get_length() == cldnn::paged_attention::block_size_xattn) { + prim.has_xattention = true; + }
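+ // e.g. a value cache shaped [num_blocks, num_kv_heads, 256, head_size] selects the XAttention + // path here, while the legacy layout [num_blocks, num_kv_heads, 16, head_size] keeps the dense + // paged-attention path.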
+ const auto k_head_size_idx = prim.has_xattention ? 3 : 2; + + auto k_head_size = has_rt_params ? rt_info.at(k_head_size_id).as<int64_t>() : key_cache_ps[k_head_size_idx].get_length(); auto v_head_size = has_rt_params ? rt_info.at(v_head_size_id).as<int64_t>() : value_cache_ps[3].get_length(); auto kv_heads_num = has_rt_params ? rt_info.at(num_k_heads_id).as<int64_t>() : key_cache_ps[1].get_length(); + if (prim.has_xattention) { + OPENVINO_ASSERT(k_head_size == v_head_size); + } // WA: in some cases, the query input may have a bounded dimension // Use input shape of the input node in such cases @@ -91,12 +101,6 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared prim.has_rotated_blocks = true; } - const size_t xattention_threshold_idx = cldnn::paged_attention::PagedAttentionInputIdx::XATTENTION_THRESHOLD; - auto xattention_threshold_input = ov::as_type_ptr(op->get_input_node_shared_ptr(xattention_threshold_idx)); - if (xattention_threshold_input && xattention_threshold_input->get_output_partial_shape(0).is_dynamic()) { - prim.has_xattention = true; - } - const size_t sinks_idx = cldnn::paged_attention::PagedAttentionInputIdx::SINKS; auto sinks_const = ov::as_type_ptr(op->get_input_node_shared_ptr(sinks_idx)); OPENVINO_ASSERT(sinks_const != nullptr); diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 68a920d38a02b5..c493169d0ca820 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -366,17 +366,15 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) { #ifdef GPU_DEBUG_CONFIG if (!config.get_use_cm()) { - OPENVINO_WARN("You may miss SDPAToVLSDPA optimization for QWenVL model," - "as CM for usage is disabled. Enable it by setting environment variable OV_GPU_USE_CM=ON."); + OPENVINO_WARN("SDPAToVLSDPA optimization for QWenVL model is disabled because CM is not enabled. " + "To enable, set environment variable OV_GPU_USE_CM=ON."); return true; } #endif if (!check_cm_jit_support(engine, config)) { - OPENVINO_WARN("You may miss SDPAToVLSDPA optimization for QWenVL model," - "as current IGC version is not compatible to the CM kernel used. Enable it by update IGC." - "Please also make sure clangFEWrapper for CM is present by checking environment varibles like " - "CM_FE_DIR or LD_LIBRARY_PATH if you are using Linux."); + OPENVINO_WARN("SDPAToVLSDPA optimization for QWenVL model unavailable: IGC version incompatible with CM kernel. " + "Update IGC and ensure clangFEWrapper for CM is available (check CM_FE_DIR or LD_LIBRARY_PATH on Linux)."); return true; } @@ -510,36 +508,106 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) { // To handle this case, "KeepConstPrecision" is executed again.
manager.register_pass(supported_woq_types, !device_info.supports_immad); - ov::pass::ConvertPagedAttnInputs::KVCacheConfig kv_cache_config; - kv_cache_config.keyCachePrecision = config.get_kv_cache_precision(); - kv_cache_config.valueCachePrecision = config.get_kv_cache_precision(); - kv_cache_config.inferencePrecision = infer_precision; - kv_cache_config.keyCacheBlockSize = 16; - kv_cache_config.keyCacheDimOrder = {0, 1, 3, 2}; - kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); - kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; - kv_cache_config.valueCacheBlockSize = 16; - kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; - kv_cache_config.valueCacheQuantBychannel = false; - kv_cache_config.valueCacheGroupSize = 0; - - manager.register_pass(kv_cache_config, - [&infer_precision](const ov::element::Type& precision, - const bool bychannel, - const size_t group_num, - int64_t& head_size, - int64_t& block_size) { - if (bychannel) { - // TODO: need to handle group size != block size case - if (precision == ov::element::i8 || precision == ov::element::u8) { - block_size += infer_precision.size() * 2; - } - } else { - if (precision == ov::element::i8 || precision == ov::element::u8) { - head_size += infer_precision.size() * 2 * group_num; - } + { + // Disable XAttention if the GPU architecture is not Xe2/Xe3 or the IGC version is incompatible. + auto check_xattn_gpu_compatibility = [&]() -> bool { + auto& engine = m_context->get_engine(); + const auto& info = engine.get_device_info(); + if (info.arch != cldnn::gpu_arch::xe2 && info.arch != cldnn::gpu_arch::xe3) { // CM optimized for systolic-array architectures + return false; + } + +#ifdef GPU_DEBUG_CONFIG + if (!config.get_use_cm()) { + OPENVINO_WARN("XAttention optimization is disabled because CM is not enabled. " + "To enable, set environment variable OV_GPU_USE_CM=ON."); + return false; + } +#endif + + if (!check_cm_jit_support(engine, config)) { + OPENVINO_WARN("XAttention optimization unavailable: IGC version incompatible with CM kernel. " + "Update IGC and ensure clangFEWrapper for CM is available (check CM_FE_DIR or LD_LIBRARY_PATH on Linux)."); + return false; + } + + return true; + }; + + + // Check if XAttention is enabled by the user via GENAI. + // This is determined by inspecting the model parameters for XAttention configurations, + // which are introduced during the SDPAToPagedAttention pass. + bool use_xattention = false; + const auto& parameters = func->get_parameters(); + for (const auto& param : parameters) { + if (param->get_friendly_name() == "xattention_block_size") { + use_xattention = true; + break; } - }); + } + + if (use_xattention) { + // Throw if xattn is not supported by either the GPU architecture or the compiler. + if (!check_xattn_gpu_compatibility()) + OPENVINO_THROW("[GPU] XAttention is not supported by your current GPU architecture or IGC version. 
" + "Please either disable XAttention by following the GenAI guide, or switch to a GPU with Xe2/Xe3 " + "architecture and ensure the latest IGC is installed."); + } + + // KVCache layout with default attention - + // k: [num_blocks, num_kv_heads, head_size, block_size(16)] + // v: [num_blocks, num_kv_heads, block_size(16), head_size] + // KVCache layout with XAttention - + // k: [num_blocks, num_kv_heads, block_size(256), head_size] + // v: [num_blocks, num_kv_heads, block_size(256), head_size] + ov::pass::ConvertPagedAttnInputs::KVCacheConfig kv_cache_config; + const auto kv_cache_precision = config.get_kv_cache_precision(); + kv_cache_config.keyCachePrecision = kv_cache_precision; + kv_cache_config.valueCachePrecision = kv_cache_precision; + kv_cache_config.inferencePrecision = infer_precision; + if (use_xattention) { + kv_cache_config.keyCacheBlockSize = cldnn::paged_attention::block_size_xattn; + kv_cache_config.keyCacheDimOrder = {0, 1, 2, 3}; // default dim order of [num_blocks, num_kv_heads, block_size, head_size] + } else { + kv_cache_config.keyCacheBlockSize = cldnn::paged_attention::block_size; + kv_cache_config.keyCacheDimOrder = {0, 1, 3, 2}; + } + kv_cache_config.keyCacheQuantBychannel = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL); + kv_cache_config.keyCacheGroupSize = (config.get_key_cache_quant_mode() == ov::internal::CacheQuantMode::BY_CHANNEL) ? 16 : 0; + if (use_xattention) { + if (kv_cache_config.keyCacheQuantBychannel && + ((kv_cache_precision == ov::element::i8 || kv_cache_precision == ov::element::u8)) ) + OPENVINO_THROW("[GPU] XAttention does not currently support per-channel quantized key cache."); + + kv_cache_config.valueCacheBlockSize = cldnn::paged_attention::block_size_xattn; + kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; + } else { + kv_cache_config.valueCacheBlockSize = cldnn::paged_attention::block_size; + kv_cache_config.valueCacheDimOrder = {0, 1, 2, 3}; + } + kv_cache_config.valueCacheQuantBychannel = false; + kv_cache_config.valueCacheGroupSize = 0; + + manager.register_pass(kv_cache_config, + [&infer_precision](const ov::element::Type& precision, + const bool bychannel, + const size_t group_num, + int64_t& head_size, + int64_t& block_size) { + OPENVINO_ASSERT(precision != ov::element::dynamic, "[GPU] kv_cache precision should be specified."); + if (bychannel) { + // TODO: need to handle group size != block size case + if (precision == ov::element::i8 || precision == ov::element::u8) { + block_size += infer_precision.size() * 2; + } + } else { + if (precision == ov::element::i8 || precision == ov::element::u8) { + head_size += infer_precision.size() * 2 * group_num; + } + } + }); + } pass_config->set_callback([&](const std::shared_ptr node){ if (!config.get_enable_sdpa_optimization()) diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp index 093a77b685870f..536b844417fab2 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -1,20 +1,23 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // -#include "test_utils.h" -#include "random_generator.hpp" - #include #include #include -#include #include +#include #include #include #include #include #include +#include + +#include "openvino/runtime/tensor.hpp" +#include 
"primitive_inst.h" +#include "random_generator.hpp" +#include "test_utils.h" using namespace cldnn; using namespace ov::intel_gpu; @@ -131,7 +134,8 @@ struct PagedAttentionManager { bool kv_cache_compression, ov::internal::CacheQuantMode key_cache_quant_mode, bool has_score_aggregation, - CacheRotationDescriptor rotation_config) + CacheRotationDescriptor rotation_config, + std::vector threshold) : num_heads(num_heads) , num_kv_heads(num_kv_heads) , k_head_size(k_head_size) @@ -149,15 +153,18 @@ struct PagedAttentionManager { // init subsequence_begins and block_indices_begins subsequence_begins.push_back(0); block_indices_begins.push_back(0); + for (int i = 0; i < static_cast(threshold.size()); i++) { + xattention_threshold.emplace_back(static_cast(threshold[i])); + } int max_len = 0; for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { const auto& subsequence_desc = subsequence_descs[i]; max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); - query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, k_head_size)); - key_data.push_back(generate_input_data(rg, num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); - value_data.push_back(generate_input_data(rg, num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); + query_data.push_back(generate_realistic_data(num_heads, subsequence_desc.num_tokens, k_head_size)); + key_data.push_back(generate_realistic_data(num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, k_head_size)); + value_data.push_back(generate_realistic_data(num_kv_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, v_head_size)); past_lens.push_back(subsequence_desc.past_len); int subsequence_start_pos = subsequence_begins[i]; @@ -224,6 +231,59 @@ struct PagedAttentionManager { return get_QKV_memory(value_data, num_kv_heads, v_head_size, true); } + memory::ptr get_key_cache_memory_cm() { + auto key_cache_dt = data_types::f16; + auto adjusted_head_size = k_head_size; + if (kv_cache_compression) { + key_cache_dt = data_types::i8; + adjusted_head_size += 4; + } + + auto num_blocks = block_indices.back() + 1; + auto key_cache_shape = ov::PartialShape{num_blocks, num_kv_heads, block_size, adjusted_head_size}; + auto key_cache_layout = layout{key_cache_shape, key_cache_dt, format::bfyx}; + auto memory = test_engine.allocate_memory(key_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len + 1, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? 
(past_len - block_size * block_idx) : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + input_token_offset * num_kv_heads * v_head_size + head_idx * v_head_size; + if (kv_cache_compression) { + auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size); + auto quantized_data_ptr = quantized_data.data(); + + // shape: [num_blocks, num_kv_heads, block_size, adjusted_head_size] + size_t output_block_offset = + (start_block_idx + block_idx) * num_kv_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size; + size_t output_offset = output_block_offset + token_idx * v_head_size; + set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset); + + size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2; + set_values(test_stream, memory, &scale, 1, comp_offset + token_idx); + set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx); + } else { + // shape: [num_blocks, num_kv_heads, block_size, v_head_size] + size_t output_offset = (start_block_idx + block_idx) * num_kv_heads * block_size * v_head_size + + head_idx * block_size * v_head_size + token_idx * v_head_size; + + set_values(test_stream, memory, data_ptr, v_head_size, output_offset); + } + } + } + } + } + } + + return memory; + } + memory::ptr get_key_cache_memory() { auto key_cache_dt = data_types::f16; auto adjusted_head_size = k_head_size; @@ -244,7 +304,7 @@ struct PagedAttentionManager { for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { int past_len = subsequence_descs[i].past_len; if (past_len != 0) { - int blocks_num = ceil_div(past_len, block_size); + int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_size * block_idx) @@ -338,7 +398,7 @@ struct PagedAttentionManager { for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { int past_len = subsequence_descs[i].past_len; if (past_len != 0) { - int blocks_num = ceil_div(past_len, block_size); + int blocks_num = ceil_div(past_len + 1, block_size); int start_block_idx = block_indices[block_indices_begins[i]]; for (int block_idx = 0; block_idx < blocks_num; block_idx++) { int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_size * block_idx) @@ -553,6 +613,26 @@ struct PagedAttentionManager { return data; } + static std::vector generate_realistic_data(size_t num_heads, size_t tokens_num, size_t k_head_size) { + std::vector data(num_heads * tokens_num * k_head_size); + + std::mt19937 gen(1234); + std::normal_distribution dist(0.0f, 0.1f); + + for (size_t h = 0; h < num_heads; ++h) { + for (size_t t = 0; t < tokens_num; ++t) { + for (size_t d = 0; d < k_head_size; ++d) { + float val = dist(gen); + if (t > 0) + val = 0.8f * val + 0.2f * static_cast(data[h * tokens_num * k_head_size + (t - 1) * k_head_size + d]); + data[h * tokens_num * k_head_size + t * k_head_size + d] = static_cast(val); + } + } + } + + return data; + } + static std::vector generate_rotation_deltas_data(tests::random_generator& rg, size_t max_tokens_num, size_t rotated_blocks_num, size_t block_size, bool per_block) { const size_t total_elements_num = per_block ? 
rotated_blocks_num : rotated_blocks_num * block_size; @@ -627,7 +707,7 @@ struct PagedAttentionReference { , test_engine(pam.test_engine) , test_stream(pam.test_stream) {} - std::pair, std::vector> get_reference() { + std::pair, std::vector> get_reference(bool has_xattention, std::vector threshold) { std::vector ref_data_output; std::vector ref_scores_output; @@ -663,8 +743,9 @@ struct PagedAttentionReference { } auto window_size = pam.has_score_aggregation ? pam.score_aggregation[i] : 1; - - auto subsequence_ref_results = run_reference(pam.query_data[i], + double th = static_cast(threshold.size() == 1 ? threshold[0] : threshold[i]); + auto subsequence_ref_results = run_reference(has_xattention, + pam.query_data[i], key_data, pam.value_data[i], subsequence_desc.num_tokens, @@ -675,7 +756,8 @@ struct PagedAttentionReference { pam.v_head_size, window_size, pam.sliding_window_size, - pam.get_default_scale()); + pam.get_default_scale(), + th); // concatenate all subsequences into one vector ref_data_output.insert(ref_data_output.end(), @@ -690,27 +772,62 @@ struct PagedAttentionReference { } private: - std::pair, std::vector> - run_reference(const std::vector& query_data, - const std::vector& key_data, - const std::vector& value_data, - int num_queries, - int num_keys, - int num_heads, - int num_kv_heads, - int k_head_size, - int v_head_size, - int window_size, - int sliding_window_size, - float scale) { + std::pair, std::vector> run_reference(bool has_xattention, + const std::vector& query_data, + const std::vector& key_data, + const std::vector& value_data, + int num_queries, + int num_keys, + int num_heads, + int num_kv_heads, + int k_head_size, + int v_head_size, + int window_size, + int sliding_window_size, + float scale, + double threshold = 0.9, + size_t block_size = 128, + size_t stride = 16) { auto query_shape = ov::PartialShape{1, num_queries, num_heads, k_head_size}; auto key_shape = ov::PartialShape{1, num_keys, num_kv_heads, k_head_size}; auto value_shape = ov::PartialShape{1, num_keys, num_kv_heads, v_head_size}; - if (num_heads != num_kv_heads) { + if (num_heads != num_kv_heads && !has_xattention) { query_shape = ov::PartialShape{num_queries, num_kv_heads, (num_heads / num_kv_heads), k_head_size}; key_shape = ov::PartialShape{num_keys, num_kv_heads, 1, k_head_size}; value_shape = ov::PartialShape{num_keys, num_kv_heads, 1, v_head_size}; } + bool do_gqa_expand = false; + std::vector expanded_key_data; + std::vector expanded_value_data; + if (has_xattention) { + // Grouped Query Attention + do_gqa_expand = (num_heads != num_kv_heads); + if (do_gqa_expand) { + const int group_size = num_heads / num_kv_heads; + + expanded_key_data.resize(static_cast(num_keys) * static_cast(num_heads) * static_cast(k_head_size)); + expanded_value_data.resize(static_cast(num_keys) * static_cast(num_heads) * static_cast(v_head_size)); + + for (int key_idx = 0; key_idx < num_keys; ++key_idx) { + for (int h = 0; h < num_heads; ++h) { + const int src_kv_head = h / group_size; + size_t src_key_off = (static_cast(key_idx) * static_cast(num_kv_heads) + static_cast(src_kv_head)) * static_cast(k_head_size); + size_t dst_key_off = (static_cast(key_idx) * static_cast(num_heads) + static_cast(h)) * static_cast(k_head_size); + for (int d = 0; d < k_head_size; ++d) + expanded_key_data[dst_key_off + static_cast(d)] = key_data[src_key_off + static_cast(d)]; + + size_t src_val_off = (static_cast(key_idx) * static_cast(num_kv_heads) + static_cast(src_kv_head)) * static_cast(v_head_size); + size_t dst_val_off = 
(static_cast(key_idx) * static_cast(num_heads) + static_cast(h)) * static_cast(v_head_size); + for (int d = 0; d < v_head_size; ++d) + expanded_value_data[dst_val_off + static_cast(d)] = value_data[src_val_off + static_cast(d)]; + } + } + + key_shape = ov::PartialShape{1, num_keys, num_heads, k_head_size}; + value_shape = ov::PartialShape{1, num_keys, num_heads, v_head_size}; + num_kv_heads = num_heads; + } + } auto query_layout = layout{query_shape, data_types::f16, format::bfyx}; auto key_layout = layout{key_shape, data_types::f16, format::bfyx}; @@ -718,20 +835,77 @@ struct PagedAttentionReference { auto scale_layout = cldnn::layout({1}, data_types::f16, format::bfyx); OPENVINO_ASSERT(query_layout.count() == query_data.size()); - OPENVINO_ASSERT(key_layout.count() == key_data.size()); - OPENVINO_ASSERT(value_layout.count() == value_data.size()); + if (do_gqa_expand) { + OPENVINO_ASSERT(key_layout.count() == expanded_key_data.size()); + OPENVINO_ASSERT(value_layout.count() == expanded_value_data.size()); + } else { + OPENVINO_ASSERT(key_layout.count() == key_data.size()); + OPENVINO_ASSERT(value_layout.count() == value_data.size()); + } auto query_mem = test_engine.allocate_memory(query_layout); auto key_mem = test_engine.allocate_memory(key_layout); auto value_mem = test_engine.allocate_memory(value_layout); - auto mask_mem = get_mask_mem(num_queries, num_keys, num_heads, sliding_window_size); auto scale_mem = test_engine.allocate_memory(scale_layout); set_values(query_mem, query_data); - set_values(key_mem, key_data); - set_values(value_mem, value_data); + if (do_gqa_expand) { + set_values(key_mem, expanded_key_data); + set_values(value_mem, expanded_value_data); + } else { + set_values(key_mem, key_data); + set_values(value_mem, value_data); + } set_values(scale_mem, {static_cast(scale)}); + ov::reference::XAttentionRetainedBlockIndicesForAllHeads retained_blocks; + if (num_queries >= static_cast(block_size) && has_xattention) { + auto reorder_qhk_to_hqd = [&](const std::vector& src, int outer_len, int num_heads, int head_dim) { + std::vector dst(num_heads * outer_len * head_dim); + for (int h = 0; h < num_heads; ++h) { + size_t dst_h_off = static_cast(h) * outer_len * head_dim; + for (int i = 0; i < outer_len; ++i) { + size_t src_off = static_cast(i) * num_heads * head_dim + static_cast(h) * head_dim; + std::copy_n(&src[src_off], head_dim, &dst[dst_h_off + static_cast(i) * head_dim]); + } + } + return dst; + }; + + const auto query_data_3d = reorder_qhk_to_hqd(query_data, num_queries, num_heads, k_head_size); + const auto key_data_3d = reorder_qhk_to_hqd(do_gqa_expand ? 
expanded_key_data : key_data, num_keys, num_heads, k_head_size); + const size_t padded_q = ((num_queries + block_size - 1) / block_size) * block_size; + const size_t padded_k = ((num_keys + block_size - 1) / block_size) * block_size; + + std::vector query_padded(num_heads * padded_q * k_head_size, 0.f); + std::vector key_padded(num_heads * padded_k * k_head_size, 0.f); + + for (int h = 0; h < num_heads; ++h) { + const auto* q_src = &query_data_3d[h * num_queries * k_head_size]; + const auto* k_src = &key_data_3d[h * num_keys * k_head_size]; + auto* q_dst = &query_padded[h * padded_q * k_head_size]; + auto* k_dst = &key_padded[h * padded_k * k_head_size]; + + std::transform(q_src, q_src + num_queries * k_head_size, q_dst, [](ov::float16 v) { + return static_cast(v); + }); + std::transform(k_src, k_src + num_keys * k_head_size, k_dst, [](ov::float16 v) { + return static_cast(v); + }); + } + ov::reference::XAttentionBlockSelector selector(threshold, block_size, stride); + retained_blocks = selector.select_blocks(query_padded.data(), + {static_cast(num_heads), padded_q, static_cast(k_head_size)}, + key_padded.data(), + {static_cast(num_heads), padded_k, static_cast(k_head_size)}); + } + auto mask_mem = get_mask_mem_combined_multi_head(num_queries, + num_keys, + num_heads, + num_kv_heads, + sliding_window_size, + retained_blocks, + static_cast(block_size)); topology topology; if (num_heads == num_kv_heads) { topology.add(input_layout("query", query_layout), @@ -827,69 +1001,102 @@ struct PagedAttentionReference { return output_data; } - memory::ptr get_mask_mem(int num_queries, int num_keys, int num_heads, int sliding_window_size) { - /* - * Two kinds of masks: - * - * Case 1 (N == K): - * num_queries = N - * num_keys = K = N - * k_head_size = H - * Q [N, H] * K[H, N] - * QK [N, N] - * 0 1 N - * 0 [ 0, MIN, .., MIN ] - * 1 [ 0, 0, .., MIN ] - * [ .., .., .., MIN ] - * N [ 0, 0, .., 0 ] - * - * Case 2 (N != K): - * num_queries = N - * num_keys = K - * k_head_size = H - * past_len = P = K - N + 1 - * Q [N, H] * K[H, K] - * QK [N, K] - * 0 1 2 P .. K - * 0 [ 0, 0, 0, MIN, MIN, MIN ] - * 1 [ 0, 0, 0, 0, MIN, MIN ] - * [ .., .., .., .., .., MIN ] - * N [ 0, 0, 0, 0, .., 0 ] - * - * Shapes: - * Q [1, num_heads, num_queries, k_head_size] - * K [1, num_heads, k_head_size, num_keys] - * Q*K [1, num_heads, num_queries, num_keys] - */ - - auto mask_shape = ov::PartialShape{ 1, 1, num_queries, num_keys }; + memory::ptr get_mask_mem_combined_multi_head(int num_queries, + int num_keys, + int num_heads, + int num_kv_heads, + int sliding_window_size, + const ov::reference::XAttentionRetainedBlockIndicesForAllHeads& retained_blocks, + int block_size) { + int heads_per_kv = num_heads / num_kv_heads; + + ov::PartialShape mask_shape; + if (retained_blocks.empty()) { + mask_shape = ov::PartialShape{1, 1, num_queries, num_keys}; + } else if (num_heads == num_kv_heads) { + mask_shape = ov::PartialShape{1, num_heads, num_queries, num_keys}; + } else { + mask_shape = ov::PartialShape{num_kv_heads, heads_per_kv, num_queries, num_keys}; + } + auto mask_layout = layout{mask_shape, data_types::f16, format::bfyx}; auto mask_mem = test_engine.allocate_memory(mask_layout); - mem_lock mem_ptr(mask_mem, test_stream); - if (sliding_window_size == 0) { - int past_len = num_keys - num_queries + 1; - for (int i = 0; i < num_queries; i++) { - for (int j = 0; j < num_keys; j++) { - mem_ptr[i * num_keys + j] = j >= past_len + i ? 
std::numeric_limits::lowest() - : ov::float16(0.f); + size_t total_elems = mask_layout.count(); + for (size_t i = 0; i < total_elems; ++i) + mem_ptr[i] = std::numeric_limits::lowest(); + if (retained_blocks.empty()) { + if (sliding_window_size == 0) { + int past_len = num_keys - num_queries + 1; + for (int i = 0; i < num_queries; i++) { + for (int j = 0; j < num_keys; j++) { + mem_ptr[i * num_keys + j] = j >= past_len + i ? std::numeric_limits::lowest() : ov::float16(0.f); + } + } + } else { + int sliding_left = num_keys - num_queries - sliding_window_size + 1; + int past_len = num_keys - num_queries + 1; + + for (int i = 0; i < num_queries; i++) { + for (int j = 0; j < num_keys; j++) { + bool is_min; + if (num_queries == num_keys) { + is_min = (j >= sliding_left + i) && (j <= i) ? 0 : 1; + } else { + is_min = (j >= sliding_left + i) && (j < past_len + i) ? 0 : 1; + } + + mem_ptr[i * num_keys + j] = is_min ? std::numeric_limits::lowest() : ov::float16(0.f); + } } } } else { - int sliding_left = num_keys - num_queries - sliding_window_size + 1; - int past_len = num_keys - num_queries + 1; - - for (int i = 0; i < num_queries; i++) { - for (int j = 0; j < num_keys; j++) { - bool is_min; - if (num_queries == num_keys) { - is_min = (j >= sliding_left + i) && (j <= i) ? 0 : 1; + for (int h = 0; h < num_heads; ++h) { + int kv_idx = (num_heads == num_kv_heads) ? 0 : (h / heads_per_kv); + int head_in_kv = (num_heads == num_kv_heads) ? h : (h % heads_per_kv); + + size_t head_offset = (static_cast(kv_idx) * heads_per_kv + static_cast(head_in_kv)) * static_cast(num_queries) * + static_cast(num_keys); + + for (int i = 0; i < num_queries; i++) { + int left_idx = 0; + int right_idx = 0; + + if (sliding_window_size == 0) { + int past_len = num_keys - num_queries + 1; + right_idx = past_len + i - 1; + left_idx = 0; } else { - is_min = (j >= sliding_left + i) && (j < past_len + i) ? 0 : 1; + int sliding_left = num_keys - num_queries - sliding_window_size + 1; + int past_len = num_keys - num_queries + 1; + if (num_queries == num_keys) { + left_idx = sliding_left + i; + right_idx = i; + } else { + left_idx = sliding_left + i; + right_idx = past_len + i - 1; + } } - mem_ptr[i * num_keys + j] = is_min ? 
std::numeric_limits::lowest() : ov::float16(0.f); + left_idx = std::max(0, left_idx); + right_idx = std::min(num_keys - 1, right_idx); + + for (const auto& [q_block_idx, k_block_idx] : retained_blocks[h]) { + int q_start = q_block_idx * block_size; + int q_end = std::min(q_start + block_size, num_queries); + int k_start = k_block_idx * block_size; + int k_end = std::min(k_start + block_size, num_keys); + + if (i < q_start || i >= q_end) + continue; + + for (int j = k_start; j < k_end; j++) { + if (j >= left_idx && j <= right_idx) { + mem_ptr[head_offset + i * num_keys + j] = ov::float16(0.f); + } + } + } } } } @@ -961,7 +1168,8 @@ struct PagedAttentionTest : public ::testing::TestWithParam { p.kv_cache_compression, p.key_cache_quant_mode, p.scores_mode == ScoresMode::SNAPKV, - p.rotation_config); + p.rotation_config, + p.threshold); if (p.kv_cache_compression) tolerance = 25e-3; @@ -969,8 +1177,13 @@ struct PagedAttentionTest : public ::testing::TestWithParam { auto query_mem = pam.get_query_memory(); auto key_mem = pam.get_key_memory(); auto value_mem = pam.get_value_memory(); - - auto key_cache_mem = pam.get_key_cache_memory(); + + memory::ptr key_cache_mem; + if (p.has_xattention) { + key_cache_mem = pam.get_key_cache_memory_cm(); + } else { + key_cache_mem = pam.get_key_cache_memory(); + } auto value_cache_mem = pam.get_value_cache_memory(); auto past_lens_mem = pam.get_past_lens_memory(); @@ -1112,6 +1325,9 @@ struct PagedAttentionTest : public ::testing::TestWithParam { pa_prim.has_score_aggregation = p.scores_mode == ScoresMode::SNAPKV; pa_prim.sliding_window = p.sliding_window_size; pa_prim.is_key_by_channel = (p.key_cache_quant_mode == ov::internal::CacheQuantMode::BY_CHANNEL); + if (p.has_xattention) { + pa_prim.has_xattention = true; + } topology topology; @@ -1190,8 +1406,12 @@ struct PagedAttentionTest : public ::testing::TestWithParam { if (p.scores_mode != ScoresMode::DISABLED) { output_scores_mem = outputs.at("output_scores").get_memory(); } - auto ref_data = PagedAttentionReference(pam).get_reference(); - compare(output_data_mem, output_scores_mem, ref_data); + auto ref_data = PagedAttentionReference(pam).get_reference(p.has_xattention, p.threshold); + if (p.has_xattention) { + compare_xattention(output_data_mem, output_scores_mem, ref_data); + } else { + compare(output_data_mem, output_scores_mem, ref_data); + } } void compare(memory::ptr data_output_mem, memory::ptr scores_output_mem, std::pair, std::vector> ref_data) { @@ -1211,6 +1431,39 @@ struct PagedAttentionTest : public ::testing::TestWithParam { } } } + + void compare_xattention(memory::ptr data_output_mem, memory::ptr scores_output_mem, std::pair, std::vector> ref_data) { + if (data_output_mem) { + ASSERT_EQ(data_output_mem->count(), ref_data.first.size()); + mem_lock mem_ptr(data_output_mem, get_test_stream()); + int mismatch_count = 0; + for (size_t i = 0; i < data_output_mem->count(); i++) { + if (std::fabs(static_cast(mem_ptr[i]) - static_cast(ref_data.first[i])) > tolerance) { + mismatch_count++; + } + } + EXPECT_LE(mismatch_count, int(data_output_mem->count() * 0.04)); + } + + if (scores_output_mem) { + ASSERT_EQ(scores_output_mem->count(), ref_data.second.size()); + mem_lock mem_ptr(scores_output_mem, get_test_stream()); + int mismatch_count = 0; + for (size_t i = 0; i < scores_output_mem->count(); i++) { + if (std::fabs(static_cast(mem_ptr[i]) - static_cast(ref_data.second[i])) > tolerance) { + mismatch_count++; + } + } + EXPECT_LE(mismatch_count, int(scores_output_mem->count() * 0.04)); + } + } + + 
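// Minimal sketch (hypothetical helper, not used by the tests) of how the retained
// (q_block, k_block) pairs produced by the block selector translate into the additive mask
// built in get_mask_mem_combined_multi_head, simplified to a single head and no sliding
// window: every position starts at lowest() and only causally valid positions inside a
// retained block are zeroed, so softmax ignores everything else.
static void fill_block_sparse_mask(std::vector<float>& mask,  // num_queries * num_keys, pre-filled with lowest()
                                   const std::vector<std::pair<size_t, size_t>>& retained_pairs,  // (q_block, k_block) for one head
                                   int num_queries,
                                   int num_keys,
                                   int block_size) {
    const int past_len = num_keys - num_queries + 1;  // causal diagonal offset
    for (const auto& [q_block, k_block] : retained_pairs) {
        const int q_begin = static_cast<int>(q_block) * block_size;
        const int q_end = std::min(q_begin + block_size, num_queries);
        const int k_begin = static_cast<int>(k_block) * block_size;
        const int k_end = std::min(k_begin + block_size, num_keys);
        for (int i = q_begin; i < q_end; i++) {
            const int last_visible_key = past_len + i - 1;  // last key query i may attend to
            for (int j = k_begin; j < k_end; j++) {
                if (j <= last_visible_key)
                    mask[static_cast<size_t>(i) * num_keys + j] = 0.f;
            }
        }
    }
}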
static bool check_cm_available() { + auto& engine = get_test_engine(); + ExecutionConfig config = get_test_default_config(engine); + return cldnn::check_cm_jit_support(engine, config) && + (engine.get_device_info().arch == gpu_arch::xe2 || engine.get_device_info().arch == gpu_arch::xe3); + } }; struct paged_attention_test_params { @@ -1220,7 +1473,9 @@ struct paged_attention_test_params { int k_head_size; int v_head_size; int block_size; + std::vector threshold; int sliding_window_size; + bool has_xattention; bool kv_cache_compression; ov::internal::CacheQuantMode key_cache_quant_mode; bool dynamic_paddings; @@ -1236,6 +1491,15 @@ TEST_P(paged_attention_test, basic) { execute(p); } +class xattention_test : public PagedAttentionTest {}; +TEST_P(xattention_test, basic) { + if (!check_cm_available()) + GTEST_SKIP(); + auto p = GetParam(); + + execute(p); +} + const auto ENABLE_CACHE_COMPRESSION = true; const auto DISABLE_CACHE_COMPRESSION = false; const auto DISABLE_SCORES = ScoresMode::DISABLED; @@ -1250,125 +1514,142 @@ const auto ENABLE_FA_V2 = false; const auto DISABLE_FA_V2 = true; INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ - /* with scores output, use SnapKV */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, 
DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES_SNAPKV, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* with scores output */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 16, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 128, 96, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 48, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 96, 128, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, 
// 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 32, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 96, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 48, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 16, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 16, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 128, 96, 16, 
{100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 48, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 96, 128, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 32, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 96, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 48, 16, 
{100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 16, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* without scores output, dynamic input query paddings */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, 
DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token long + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token /* with scores, per_block rotation */ - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token long - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, 0, 
DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 48, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token - paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN,STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, 
DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token long
+    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token long
+    paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 48, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
    /* with scores, per_token rotation */
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 128, 192, 16, 0, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 128, 192, 16, {100.0}, 0, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
    /* without scores output, dynamic input query paddings, KV-cache compression */
-    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
-    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
-    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
-    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
-    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long
-    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long
-    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long
-    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long
-    paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
-    paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
-    paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
-    paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long
+    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long
+    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long
+    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token long
+    paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{10, 0}, {81, 0}, {129, 0}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
    /* with scores, per_block rotation, KV-cache compression */
-    paged_attention_test_params{ {{1, 34}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token
-    paged_attention_test_params{ {{1, 34}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, DISABLE_FA_V2 }, // 2nd token
    /* with scores, per_token rotation, KV-cache compression */
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
-    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 2, 64, 32, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
    /* With sliding windows */
-    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 6, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
-    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, 6, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
-    paged_attention_test_params{ {{512, 0}}, 2, 2, 64, 32, 16, 20, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
-    paged_attention_test_params{ {{512, 0}}, 2, 2, 64, 32, 16, 20, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
-    paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, 8, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
-    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 64, 16, 128, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
-    paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, 4, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
-    paged_attention_test_params{ {{5, 10}}, 2, 2, 64, 64, 16, 2, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
-    paged_attention_test_params{ {{8, 8}}, 16, 16, 256, 256, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{34, 0}}, 2, 2, 32, 32, 16, 2, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 1st token
-    paged_attention_test_params{ {{1, 1008}}, 32, 32, 128, 128, 16, 6, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token
-    paged_attention_test_params{ {{6, 20}}, 2, 2, 128, 128, 16, 8, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{254, 10}}, 32, 8, 128, 128, 16, 10, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching, GQA
-    paged_attention_test_params{ {{84, 2}}, 32, 32, 128, 128, 16, 16, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{1008, 492}}, 32, 32, 32, 32, 16, 32, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{1008, 492}}, 16, 16, 64, 64, 16, 64, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{1008, 492}}, 8, 8, 128, 128, 16, 128, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{2, 30}}, 2, 2, 64, 64, 16, 0, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{232, 24}}, 2, 2, 512, 512, 16, 32, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{1008, 592}}, 32, 32, 128, 128, 16, 64, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{1008, 692}}, 32, 32, 128, 128, 16, 128, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{1008, 792}}, 32, 32, 128, 128, 16, 256, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
-    paged_attention_test_params{ {{1, 34}, {2, 20}, {10, 34}}, 2, 2, 64, 64, 16, 10, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 6, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{10, 0}}, 2, 2, 64, 64, 16, {100.0}, 6, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{512, 0}}, 2, 2, 64, 32, 16, {100.0}, 20, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{512, 0}}, 2, 2, 64, 32, 16, {100.0}, 20, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 2, 64, 64, 16, {100.0}, 8, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token + 2nd token
+    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 2, 48, 64, 16, {100.0}, 128, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION, ENABLE_FA_V2 }, // 1st token + 1st token
+    paged_attention_test_params{ {{1, 10}}, 2, 2, 64, 64, 16, {100.0}, 4, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{5, 10}}, 2, 2, 64, 64, 16, {100.0}, 2, false, DISABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{8, 8}}, 16, 16, 256, 256, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{34, 0}}, 2, 2, 32, 32, 16, {100.0}, 2, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{1, 1008}}, 32, 32, 128, 128, 16, {100.0}, 6, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{6, 20}}, 2, 2, 128, 128, 16, {100.0}, 8, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{254, 10}}, 32, 8, 128, 128, 16, {100.0}, 10, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching, GQA
+    paged_attention_test_params{ {{84, 2}}, 32, 32, 128, 128, 16, {100.0}, 16, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{1008, 492}}, 32, 32, 32, 32, 16, {100.0}, 32, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{1008, 492}}, 16, 16, 64, 64, 16, {100.0}, 64, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{1008, 492}}, 8, 8, 128, 128, 16, {100.0}, 128, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{2, 30}}, 2, 2, 64, 64, 16, {100.0}, 0, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{232, 24}}, 2, 2, 512, 512, 16, {100.0}, 32, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_CHANNEL, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{1008, 592}}, 32, 32, 128, 128, 16, {100.0}, 64, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{1008, 692}}, 32, 32, 128, 128, 16, {100.0}, 128, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{1008, 792}}, 32, 32, 128, 128, 16, {100.0}, 256, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: prefix caching
+    paged_attention_test_params{ {{1, 34}, {2, 20}, {10, 34}}, 2, 2, 64, 64, 16, {100.0}, 10, false, ENABLE_CACHE_COMPRESSION,ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION, DISABLE_FA_V2 }, // mixed: 2nd token + 1st token + part of 1st token
+}));
+
+INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, xattention_test, ::testing::ValuesIn(std::vector<paged_attention_test_params>{
+    paged_attention_test_params{ {{32, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{2048, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{32, 0}}, 4, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 4, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token
+    paged_attention_test_params{ {{2048, 0}}, 4, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 }, // 1st token
+
+    paged_attention_test_params{ {{1, 31}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 32}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 1023}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 1024}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 127}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 128}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 129}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+    paged_attention_test_params{ {{1, 32}}, 28, 28, 128, 128, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 }, // 2nd token
+}));
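For readers decoding the positional initializers above: relative to the old seven-value prefix (subsequence descriptors, head counts, head sizes, block size, sliding-window size), each entry now carries a vector of XAttention thresholds after the block size and a boolean XAttention switch after the sliding-window size. The sketch below is an illustrative assumption inferred from the diff; the field names are hypothetical, not the PR's actual declarations.

// Illustrative sketch only -- names are hypothetical; the order is inferred
// from the initializer lists in the test tables above.
#include <vector>

struct subsequence_descriptor {
    int num_tokens;  // new query tokens; {N, 0} reads as a prefill ("1st token") case
    int past_len;    // tokens already in the KV cache; {1, M} reads as a decode ("2nd token") case
};

struct paged_attention_test_params {
    std::vector<subsequence_descriptor> subsequences;
    int num_heads;                         // 4 query heads vs. 2 KV heads exercises GQA
    int num_kv_heads;
    int k_head_size;
    int v_head_size;
    int block_size;                        // 16 for plain paged attention; 256 (block_size_xattn) in the XAttention suite
    std::vector<double> xattn_thresholds;  // new: {100.0} (>= 1.0) keeps execution dense, {0.9} enables sparse XAttention
    int sliding_window;                    // 0 disables the sliding window
    bool use_xattention;                   // new: true only in smoke_cm_xattention
    bool cache_compression;                // ENABLE_/DISABLE_CACHE_COMPRESSION
    // ... followed by the cache quant mode, input-padding mode, scores, rotation, and FA-v2 flags
};

Under this reading, {{1, 34}, {25, 0}, {10, 34}} is one decode subsequence plus a full prefill plus a partially cached prefill, which matches the "mixed" comments, and the {100.0} threshold in the pre-existing suites presumably keeps those tests on the dense path via the new allow_bypass_xattn option.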