
Commit d9a7806

integrate xattn_post_proc kernel; the FP16 kernel works. TODO: verify u8 kvcache.
1 parent d8e66d7 commit d9a7806

4 files changed: +150 −1 lines


src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.cpp

Lines changed: 8 additions & 0 deletions
@@ -37,6 +37,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM {
     Stage::Ptr pa_multi_token = make_stage<PagedAttentionGeneratorMultiToken>();
     Stage::Ptr xattn_estimate_gemmqk = make_stage<XAttentionEstimateGEMMQK>();
     Stage::Ptr xattn_estimate_find_block = make_stage<XAttentionEstimateFindBlock>();
+    Stage::Ptr xattn_estimate_post_proc = make_stage<XAttentionEstimatePostProc>();
 
     PagedAttentionCmImpl(): PrimitiveImplCM(PagedAttentionImplementationManager::get_type_info_static()) {
         m_rt_params = std::make_unique<PagedAttentionRuntimeParams>();
@@ -53,6 +54,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM {
         if (xattn_block_size > 1) {
             add_stage(xattn_estimate_gemmqk, params);
             add_stage(xattn_estimate_find_block, params);
+            add_stage(xattn_estimate_post_proc, params);
         }
     }
 
@@ -124,6 +126,7 @@ class PagedAttentionCmImpl : public PrimitiveImplCM {
                 pa_id++;
             }
 #endif
+            res_event = {execute_stage(res_event, instance, xattn_estimate_post_proc)};
         }
         res_event = {execute_stage(res_event, instance, pa_multi_token)};
     } else if (rt_params->stage == PagedAttentionStage::GENERATE) {
@@ -202,6 +205,11 @@ class PagedAttentionCmImpl : public PrimitiveImplCM {
 
         auto count_elements_mask = static_cast<int64_t>(desc->heads_num * q_block_pad * k_block_pad);
         internal_buffers.emplace_back(count_elements_mask, ov::element::boolean);  // 4: sparse_block_mask
+
+        const uint32_t MERGED_Q_NUM = 2;  // TODO
+        const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM);
+        auto count_elements_mask_merged = static_cast<int64_t>(desc->heads_num * q_block_pad_merged * k_block_pad);
+        internal_buffers.emplace_back(count_elements_mask_merged, ov::element::boolean);  // 5: sparse_block_mask_wg
     }
 }
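
Note: the new buffer 5 is buffer 4 with the q dimension folded by MERGED_Q_NUM. A minimal standalone sketch of the sizing arithmetic (plain C++; the heads_num / q_block_pad / k_block_pad values are illustrative, and ceil_div is restated locally since its definition is outside this diff):

#include <cstdint>
#include <iostream>

// Local stand-in for the ceil_div helper referenced in the diff.
static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    const uint32_t heads_num = 32, q_block_pad = 33, k_block_pad = 64;  // illustrative values
    const uint32_t MERGED_Q_NUM = 2;

    const int64_t count_elements_mask = int64_t(heads_num) * q_block_pad * k_block_pad;                // buffer 4
    const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM);                           // 17
    const int64_t count_elements_mask_merged = int64_t(heads_num) * q_block_pad_merged * k_block_pad;  // buffer 5

    std::cout << count_elements_mask << " vs " << count_elements_mask_merged << "\n";  // 67584 vs 34816
    return 0;
}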

src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp

Lines changed: 73 additions & 1 deletion
@@ -429,8 +429,10 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp
     args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS});  // subsequence_begins
 
     const size_t block_size = get_xattn_block_size(params);
-    if (block_size > 1)
+    if (block_size > 1) {
         args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4});  // sparse_block_mask
+        args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5});  // sparse_block_mask_wg
+    }
 
     args.push_back({ArgumentDescriptor::Types::OUTPUT, 0});

@@ -944,4 +946,74 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const {
     }};
 }
 
+//-----------------------------------------------------------------------------------------------------------------
+// XAttention Estimate post_proc generator
+//-----------------------------------------------------------------------------------------------------------------
+JitConstants XAttentionEstimatePostProc::get_jit_constants(const kernel_impl_params& params) const {
+    auto jit = XAttentionEstimateGeneratorBase::get_jit_constants(params);
+
+    jit.make("MERGED_Q_NUM", 2);  // TODO
+
+    return jit;
+}
+
+Arguments XAttentionEstimatePostProc::get_arguments_desc(const kernel_impl_params& params) const {
+    Arguments args;
+
+    // inputs
+    args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4});  // block_mask
+
+    // outputs
+    args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5});  // block_mask_merged
+
+    // scalars
+    args.push_back({ArgumentDescriptor::Types::SCALAR, 0});  // q_stride_pad
+    args.push_back({ArgumentDescriptor::Types::SCALAR, 1});  // q_block_pad
+    args.push_back({ArgumentDescriptor::Types::SCALAR, 2});  // k_block_pad
+
+    return args;
+}
+
+DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func() const {
+    return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) {
+        assert(!params.is_dynamic());
+        auto& wgs = kd.params.workGroups;
+
+        const auto desc = params.typed_desc<paged_attention>();
+
+        assert(rt_params != nullptr);
+
+        const size_t block_size = get_xattn_block_size(params);
+        const size_t heads_num = desc->heads_num;
+
+        auto out_shape = params.output_layouts[0].get_shape();
+        const size_t kv_len = get_max_context_len(params) / STRIDE * STRIDE;
+        const size_t q_len = out_shape[0];
+        const uint32_t M = static_cast<uint32_t>(q_len / STRIDE);  // silently drops the tail that is shorter than `stride`
+        const uint32_t N = static_cast<uint32_t>(kv_len / STRIDE);
+        const size_t q_stride_pad = round_up_to(M, BLOCK_WG_M);
+        const size_t N_kq_groups = ceil_div(N, BLOCK_WG_N);
+
+        const uint32_t sum_per_token_in_block = static_cast<uint32_t>(block_size / STRIDE);
+        const uint32_t k_block_in_group = static_cast<uint32_t>(BLOCK_WG_N / sum_per_token_in_block);
+        const uint32_t k_block_pad = k_block_in_group * N_kq_groups;
+        const uint32_t q_block_pad = ceil_div(q_len, block_size);
+
+        const uint32_t MERGED_Q_NUM = 2;  // TODO
+        const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM);
+
+        wgs.global = {q_block_pad_merged, heads_num, 1};
+        wgs.local = {1, 1, 1};
+
+        auto& scalars = kd.params.scalars;
+        std::vector<size_t> scaler_value = {q_stride_pad, q_block_pad, k_block_pad};
+        scalars.resize(scaler_value.size());
+
+        for (size_t i = 0; i < scaler_value.size(); ++i) {
+            scalars[i].t = ScalarDescriptor::Types::UINT32;
+            scalars[i].v.u32 = static_cast<uint32_t>(scaler_value[i]);
+        }
+    }};
+}
+
 } // namespace ov::intel_gpu::cm
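
To make the padding arithmetic in get_dispatch_data_func concrete, here is a small host-side sketch that reproduces it for the same query length used in the new kernel's NOTE (256*16+1 tokens). STRIDE, BLOCK_WG_M, and BLOCK_WG_N are assumed illustrative values; the real constants come from estimate.hpp and are not part of this diff.

#include <cstdint>
#include <iostream>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
static uint32_t round_up_to(uint32_t a, uint32_t b) { return ceil_div(a, b) * b; }

int main() {
    // Assumed illustrative constants; the real STRIDE / BLOCK_WG_M / BLOCK_WG_N live in estimate.hpp.
    const uint32_t STRIDE = 16, BLOCK_WG_M = 64, BLOCK_WG_N = 128;
    const uint32_t block_size = 128;          // xattn block size
    const uint32_t q_len = 256 * 16 + 1;      // same example as the kernel NOTE
    const uint32_t max_context_len = 8192;

    const uint32_t kv_len = max_context_len / STRIDE * STRIDE;
    const uint32_t M = q_len / STRIDE;        // tail shorter than STRIDE is dropped
    const uint32_t N = kv_len / STRIDE;
    const uint32_t q_stride_pad = round_up_to(M, BLOCK_WG_M);                // 256
    const uint32_t N_kq_groups = ceil_div(N, BLOCK_WG_N);                    // 4

    const uint32_t sum_per_token_in_block = block_size / STRIDE;             // 8
    const uint32_t k_block_in_group = BLOCK_WG_N / sum_per_token_in_block;   // 16
    const uint32_t k_block_pad = k_block_in_group * N_kq_groups;             // 64
    const uint32_t q_block_pad = ceil_div(q_len, block_size);                // 33

    const uint32_t MERGED_Q_NUM = 2;
    const uint32_t q_block_pad_merged = ceil_div(q_block_pad, MERGED_Q_NUM); // 17

    // Dispatch: one work-group per merged q-block row per head -> global = {q_block_pad_merged, heads_num, 1}.
    std::cout << q_stride_pad << " " << q_block_pad << " " << k_block_pad << " " << q_block_pad_merged << "\n";
    return 0;
}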

src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.hpp

Lines changed: 8 additions & 0 deletions
@@ -134,4 +134,12 @@ class XAttentionEstimateFindBlock : public XAttentionEstimateGeneratorBase {
     [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override;
 };
 
+class XAttentionEstimatePostProc : public XAttentionEstimateGeneratorBase {
+public:
+    XAttentionEstimatePostProc() : XAttentionEstimateGeneratorBase("xattn_post_proc") {}
+    [[nodiscard]] JitConstants get_jit_constants(const kernel_impl_params& params) const override;
+    [[nodiscard]] Arguments get_arguments_desc(const kernel_impl_params& params) const override;
+    [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override;
+};
+
 } // namespace ov::intel_gpu::cm
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+/*******************************************************************************
+ * Copyright (c) 2022-2025 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+namespace KERNEL_NAME {
+#include "estimate.hpp"
+
+// NOTE: q_stride_pad / TOKEN_IN_BLOCK <= q_block_pad; example of the case q_stride_pad / TOKEN_IN_BLOCK < q_block_pad:
+//   query = 256*16+1, then
+//   q_stride_pad = 256
+//   q_stride_pad / TOKEN_IN_BLOCK = 32
+//   q_block_pad = div_up(256*16+1, 128) = 33
+// _GENX_MAIN_ void post_proc_mask(
+extern "C" _GENX_MAIN_ void KERNEL_NAME(svmptr_t block_mask ATTR, svmptr_t merged_block_mask ATTR, uint q_stride_pad, uint q_block_pad, uint k_block_pad) {
+    // block_mask:        [b, hq, q_block_pad, k_block_pad]
+    // merged_block_mask: [b, hq, q_block_pad/MERGED_Q_NUM, k_block_pad]
+    // global:            [q_block_pad/MERGED_Q_NUM, hq, b]
+    const int TOKEN_IN_BLOCK = BLOCK_SIZE / STRIDE;
+    const int TOKEN_SHARE_MAX = BLOCK_SHARE_MAX / TOKEN_IN_BLOCK;
+    uint m_merged = cm_group_id(0);
+    uint hq = cm_group_id(1);
+    uint b = cm_group_id(2);
+    block_mask += (b * HQ + hq) * q_block_pad * k_block_pad;
+    merged_block_mask += (b * HQ + hq) * cm_group_count(0) * k_block_pad;
+    merged_block_mask += m_merged * k_block_pad;
+    block_mask += m_merged * MERGED_Q_NUM * k_block_pad;
+    vector<uchar, 32> one = 1;
+    // q blocks not covered by the mask, i.e. a tail of q = 1~15 tokens, shorter than param `stride`
+    //for (int i = 0; i < MERGED_Q_NUM; i++) {
+    //    auto q_stride_cur = m_merged * MERGED_Q_NUM + i;
+    //    if (q_stride_cur >= q_stride_pad / TOKEN_IN_BLOCK && q_stride_cur < q_block_pad) {
+    //        for (int j = 0; j < k_block_pad; j += 32) {
+    //            cm_ptr_store<int, 32 / 4>((int*)block_mask, j + i * k_block_pad, one.format<int>());
+    //        }
+    //    }
+    //}
+    for (int j = 0; j < k_block_pad; j += 32) {
+        vector<uchar, 32> new_mask = cm_ptr_load<int, 8>((int*)block_mask, j).format<uchar>();
+        for (int i = 1; i < MERGED_Q_NUM; i++) {
+            if (m_merged * MERGED_Q_NUM + i < q_stride_pad / TOKEN_IN_BLOCK) {
+                vector<uchar, 32> cur_mask = cm_ptr_load<int, 8>((int*)block_mask, j + i * k_block_pad).format<uchar>();
+                new_mask &= cur_mask;
+            }
+        }
+        cm_ptr_store<int, 32 / 4>((int*)merged_block_mask, j, new_mask.format<int>());
+    }
+}
+
+} // namespace KERNEL_NAME
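
Restating what the kernel computes as a plain C++ host-side reference for one (b, hq) slice (a sketch of the semantics as read from the code above, not the CM source): MERGED_Q_NUM consecutive q-block rows of block_mask are AND-ed into one row of merged_block_mask, and rows at or beyond q_stride_pad / TOKEN_IN_BLOCK are skipped, matching the kernel's guard.

#include <cstdint>
#include <vector>

// Host-side reference of the row merge for one (b, hq) slice of the boolean block mask.
// block_mask:        q_block_pad x k_block_pad bytes
// merged_block_mask: ceil(q_block_pad / MERGED_Q_NUM) x k_block_pad bytes
void merge_mask_reference(const std::vector<uint8_t>& block_mask,
                          std::vector<uint8_t>& merged_block_mask,
                          uint32_t q_block_pad, uint32_t k_block_pad,
                          uint32_t q_stride_blocks,   // q_stride_pad / TOKEN_IN_BLOCK
                          uint32_t MERGED_Q_NUM = 2) {
    const uint32_t merged_rows = (q_block_pad + MERGED_Q_NUM - 1) / MERGED_Q_NUM;
    merged_block_mask.assign(size_t(merged_rows) * k_block_pad, 0);
    for (uint32_t m = 0; m < merged_rows; ++m) {
        for (uint32_t j = 0; j < k_block_pad; ++j) {
            // Start from the first row of the group, then AND in the remaining rows
            // that still fall inside the valid q_stride range.
            uint8_t v = block_mask[size_t(m) * MERGED_Q_NUM * k_block_pad + j];
            for (uint32_t i = 1; i < MERGED_Q_NUM; ++i) {
                if (m * MERGED_Q_NUM + i < q_stride_blocks) {
                    v &= block_mask[size_t(m * MERGED_Q_NUM + i) * k_block_pad + j];
                }
            }
            merged_block_mask[size_t(m) * k_block_pad + j] = v;
        }
    }
}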
