Commit f35beb7
[GPU] sdpa_micro for prefix caching (#31968)
### Details:
- This PR extends `sdpa_micro` to support paged attention for better performance.
- The `mixed` stage of paged attention will be handled by `sdpa_micro` instead of `pa_sdpa_opt`.
- Additionally, this PR allows `sdpa_micro` to support `sliding window`.

### Tickets:
- 169407, 170673, 172903, 173059
1 parent 48eadd7 commit f35beb7
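To summarize the control-flow change, below is a minimal, self-contained sketch of the kernel selection after this commit. It is illustrative only: the enum names and the two conditions are read directly off the diff further down, while the surrounding types, `select_kernels`, and `main` are invented for the example, and the real code additionally consults `supports_micro_sdpa(params)` and `can_use_gqa_kernel()` against actual kernel parameters.

```cpp
#include <initializer_list>
#include <iostream>

// Simplified model of the stage selection introduced by this commit.
// Enum values mirror PagedAttentionStage in the diff; everything else is
// made up for the sketch.
enum class PagedAttentionStage { GENERATE, PREFILL, MIXED };

struct RuntimeParams {
    bool use_micro_sdpa = false;  // sdpa_micro now covers PREFILL and MIXED
    bool use_gqa_kernel = false;  // GQA kernel stays a GENERATE-only path
};

RuntimeParams select_kernels(PagedAttentionStage stage, bool supports_micro_sdpa) {
    RuntimeParams rt;
    // New rule from the diff: micro SDPA is used for every stage except GENERATE.
    rt.use_micro_sdpa = supports_micro_sdpa && stage != PagedAttentionStage::GENERATE;
    // GQA kernel is only considered for GENERATE when micro SDPA is not used
    // (the real code also checks can_use_gqa_kernel(), omitted here).
    rt.use_gqa_kernel = (stage == PagedAttentionStage::GENERATE) && !rt.use_micro_sdpa;
    return rt;
}

const char* stage_name(PagedAttentionStage s) {
    switch (s) {
    case PagedAttentionStage::GENERATE: return "GENERATE";
    case PagedAttentionStage::PREFILL:  return "PREFILL";
    default:                            return "MIXED";
    }
}

int main() {
    for (auto stage : {PagedAttentionStage::PREFILL, PagedAttentionStage::MIXED, PagedAttentionStage::GENERATE}) {
        auto rt = select_kernels(stage, /*supports_micro_sdpa=*/true);
        std::cout << stage_name(stage) << ": micro=" << rt.use_micro_sdpa
                  << " gqa=" << rt.use_gqa_kernel << '\n';
    }
    return 0;
}
```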

File tree: 4 files changed, +543 / -266 lines


src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp

Lines changed: 25 additions & 24 deletions
@@ -1087,6 +1087,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
     Stage::Ptr pa_scores_calc = make_stage<PagedAttentionGeneratorScoresCalculation>();
 #ifdef ENABLE_ONEDNN_FOR_GPU
     Stage::Ptr pa_sdpa_micro = make_stage<SDPAMicroGenerator>(true);
+    Stage::Ptr pa_sdpa_micro_mixed = make_stage<SDPAMicroGenerator>(false);
 #endif
 
     PagedAttentionOptImpl() : SDPAImplBase(PagedAttentionOpt::get_type_info_static()) {}
@@ -1099,6 +1100,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         const bool use_micro_sdpa = supports_micro_sdpa(params);
         if (use_micro_sdpa) {
             add_stage(pa_sdpa_micro, params);
+            add_stage(pa_sdpa_micro_mixed, params);
         }
 #endif
 
@@ -1169,8 +1171,9 @@ class PagedAttentionOptImpl : public SDPAImplBase {
 
     size_t get_query_block_size(const PagedAttentionStage& stage, const bool use_micro_sdpa) const {
         const auto default_block_size = 16;
-        if (use_micro_sdpa && stage == PagedAttentionStage::PREFILL)
-            return get_micro_tile_qsize(pa_sdpa_micro->kd);
+        if (use_micro_sdpa) {
+            return (stage == PagedAttentionStage::PREFILL) ? get_micro_tile_qsize(pa_sdpa_micro->kd) : get_micro_tile_qsize(pa_sdpa_micro_mixed->kd);
+        }
         return default_block_size;
     }
 #else
@@ -1213,22 +1216,17 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             rt_params->paged_attention_snap_kv_tokens = 0;
         }
 
-        if (rt_params->stage == PagedAttentionStage::PREFILL) {
 #ifdef ENABLE_ONEDNN_FOR_GPU
-            // Determine if sdpa_micro can be used based on sliding_window and aliged_seq_len
-            bool support_sliding_window =
-                desc->sliding_window == 0 || (desc->sliding_window > 0 && rt_params->paged_attention_aligned_seq_len < desc->sliding_window);
-            rt_params->use_micro_sdpa = supports_micro_sdpa(params) && support_sliding_window;
+        rt_params->use_micro_sdpa = supports_micro_sdpa(params) && rt_params->stage != PagedAttentionStage::GENERATE;
 #else
-            rt_params->use_micro_sdpa = false;
+        rt_params->use_micro_sdpa = false;
 #endif
-            rt_params->query_block_size = get_query_block_size(rt_params->stage, rt_params->use_micro_sdpa);
-        } else {
-            rt_params->use_micro_sdpa = false;
-        }
+        rt_params->query_block_size = get_query_block_size(rt_params->stage, rt_params->use_micro_sdpa);
 
-        if (rt_params->stage == PagedAttentionStage::GENERATE) {
+        if (rt_params->stage == PagedAttentionStage::GENERATE && !rt_params->use_micro_sdpa) {
             rt_params->use_gqa_kernel = can_use_gqa_kernel(params, PagedAttentionStage::GENERATE, rt_params->max_context_len);
+        } else {
+            rt_params->use_gqa_kernel = false;
         }
         return;
     }
@@ -1274,9 +1272,14 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         if (rt_params->use_gqa_kernel) {
             res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token : pa_gqa_single_token)};
         } else {
-            res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token : pa_single_token)};
+#ifdef ENABLE_ONEDNN_FOR_GPU
+            if (multi_tokens_mode && rt_params->use_micro_sdpa)
+                res_event = {execute_stage(res_event, instance, pa_sdpa_micro_mixed)};
+            else
+#endif
+                res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token : pa_single_token)};
         }
-        if (num_of_partitions > 1) {
+        if (num_of_partitions > 1 && !rt_params->use_micro_sdpa) {
             res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token_finalization : pa_single_token_finalization)};
         }
     }
@@ -1364,11 +1367,9 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             const auto max_context_len = get_max_context_len(params);
             num_of_partitions = ceil_div(max_context_len, partition_size);
         }
-        bool can_use_micro_sdpa = stage == PagedAttentionStage::PREFILL;
+        bool can_use_micro_sdpa = false;
 #ifdef ENABLE_ONEDNN_FOR_GPU
-        can_use_micro_sdpa &= has_stage(pa_sdpa_micro);
-#else
-        can_use_micro_sdpa = false;
+        can_use_micro_sdpa = has_stage(pa_sdpa_micro) && stage != PagedAttentionStage::GENERATE;
 #endif
         GPU_DEBUG_TRACE_DETAIL << "get_internal_buffer_descs: stage = " << static_cast<size_t>(stage) << std::endl;
         int64_t paged_attention_aligned_seq_len = -1;
@@ -1447,13 +1448,13 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         }
 
         const auto multi_tokens_mode = stage == PagedAttentionStage::MIXED;
-        if (multi_tokens_mode) {
+        if (multi_tokens_mode && !can_use_micro_sdpa) {
             internal_buffers.emplace_back(total_tokens, softmax_accumulator_type, lockable); // 9
         }
 
 #ifdef ENABLE_ONEDNN_FOR_GPU
         if (can_use_micro_sdpa) {
-            const auto wg_tile_q = get_micro_tile_qsize(pa_sdpa_micro->kd);
+            const auto wg_tile_q = 8; // This is set as the minimum size of query block for sharing between sdpa_micro_prefill and mixed.
             const auto target_seq_len = std::max(paged_attention_aligned_seq_len, static_cast<int64_t>(1));
             const auto indexes_buf_size = ceil_div(target_seq_len, wg_tile_q) * 2;
             internal_buffers.emplace_back(indexes_buf_size * 4, indexes_dt, lockable);
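As a quick sanity check on the buffer sizing in the hunk above (an illustrative calculation, not code from the PR; `ceil_div` is re-implemented locally and the sequence length is an arbitrary example): with `wg_tile_q` pinned to 8 so the prefill and mixed micro stages can share one index buffer, an aligned sequence length of 1000 tokens yields ceil(1000 / 8) * 2 = 250 entries, or 1000 bytes at 4 bytes per entry.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Re-computes the index-buffer size from the hunk above with example numbers.
// ceil_div is a local stand-in for the plugin's helper.
static int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
    const int64_t wg_tile_q = 8;           // minimum query block shared by the prefill and mixed micro stages
    const int64_t aligned_seq_len = 1000;  // example value, not taken from the PR
    const int64_t target_seq_len = std::max<int64_t>(aligned_seq_len, 1);
    const int64_t entries = ceil_div(target_seq_len, wg_tile_q) * 2;            // 250 entries
    std::cout << "entries: " << entries << ", bytes: " << entries * 4 << '\n';  // entries: 250, bytes: 1000
    return 0;
}
```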
@@ -1552,7 +1553,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;
         std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> micro_sdpa_block_starts_and_gws_mapping_lock = nullptr;
 
-        if (stage == PagedAttentionStage::MIXED) {
+        if (stage == PagedAttentionStage::MIXED && !use_micro_sdpa) {
             size_t sequential_gws_subseq_mapping_idx = 6;
             if (has_score_aggregation) {
                 sequential_gws_subseq_mapping_idx = 9;
@@ -1567,7 +1568,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
         }
 
-        if (stage == PagedAttentionStage::PREFILL && use_micro_sdpa) {
+        if (use_micro_sdpa) {
             const auto memory_idx = 3; // intermediate_idx for micro kernel
             auto memory = intermediates_memories[memory_idx];
             micro_sdpa_block_starts_and_gws_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(memory, stream));
@@ -1619,7 +1620,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             }
         }
 
-        if (stage == PagedAttentionStage::MIXED) {
+        if (stage == PagedAttentionStage::MIXED && !use_micro_sdpa) {
             for (int32_t idx = seq_start; idx < seq_end; idx++) {
                 sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
             }
