@@ -1087,6 +1087,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
     Stage::Ptr pa_scores_calc = make_stage<PagedAttentionGeneratorScoresCalculation>();
 #ifdef ENABLE_ONEDNN_FOR_GPU
     Stage::Ptr pa_sdpa_micro = make_stage<SDPAMicroGenerator>(true);
+    Stage::Ptr pa_sdpa_micro_mixed = make_stage<SDPAMicroGenerator>(false);
 #endif
 
     PagedAttentionOptImpl() : SDPAImplBase(PagedAttentionOpt::get_type_info_static()) {}
@@ -1099,6 +1100,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         const bool use_micro_sdpa = supports_micro_sdpa(params);
         if (use_micro_sdpa) {
             add_stage(pa_sdpa_micro, params);
+            add_stage(pa_sdpa_micro_mixed, params);
         }
 #endif
 
@@ -1169,8 +1171,9 @@ class PagedAttentionOptImpl : public SDPAImplBase {
 
     size_t get_query_block_size(const PagedAttentionStage& stage, const bool use_micro_sdpa) const {
         const auto default_block_size = 16;
-        if (use_micro_sdpa && stage == PagedAttentionStage::PREFILL)
-            return get_micro_tile_qsize(pa_sdpa_micro->kd);
+        if (use_micro_sdpa) {
+            return (stage == PagedAttentionStage::PREFILL) ? get_micro_tile_qsize(pa_sdpa_micro->kd) : get_micro_tile_qsize(pa_sdpa_micro_mixed->kd);
+        }
         return default_block_size;
     }
 #else
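(Note on the hunk above: the query block size is now taken from whichever micro-kernel variant will actually run. As a hypothetical illustration, with tile sizes that are examples rather than values from this PR: if the prefill variant compiled with a query tile of 16 and the mixed variant with 8, a MIXED-stage call now returns 8, where the old code fell through to default_block_size = 16 because only PREFILL took the micro path.)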
@@ -1213,22 +1216,17 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             rt_params->paged_attention_snap_kv_tokens = 0;
         }
 
-        if (rt_params->stage == PagedAttentionStage::PREFILL) {
 #ifdef ENABLE_ONEDNN_FOR_GPU
-            // Determine if sdpa_micro can be used based on sliding_window and aliged_seq_len
-            bool support_sliding_window =
-                desc->sliding_window == 0 || (desc->sliding_window > 0 && rt_params->paged_attention_aligned_seq_len < desc->sliding_window);
-            rt_params->use_micro_sdpa = supports_micro_sdpa(params) && support_sliding_window;
+        rt_params->use_micro_sdpa = supports_micro_sdpa(params) && rt_params->stage != PagedAttentionStage::GENERATE;
 #else
-            rt_params->use_micro_sdpa = false;
+        rt_params->use_micro_sdpa = false;
 #endif
-            rt_params->query_block_size = get_query_block_size(rt_params->stage, rt_params->use_micro_sdpa);
-        } else {
-            rt_params->use_micro_sdpa = false;
-        }
+        rt_params->query_block_size = get_query_block_size(rt_params->stage, rt_params->use_micro_sdpa);
 
-        if (rt_params->stage == PagedAttentionStage::GENERATE) {
+        if (rt_params->stage == PagedAttentionStage::GENERATE && !rt_params->use_micro_sdpa) {
             rt_params->use_gqa_kernel = can_use_gqa_kernel(params, PagedAttentionStage::GENERATE, rt_params->max_context_len);
         } else {
             rt_params->use_gqa_kernel = false;
         }
         return;
     }
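Taken together, the runtime-parameter logic above reduces to the following per-stage selection (a sketch assuming ENABLE_ONEDNN_FOR_GPU is defined and supports_micro_sdpa(params) holds; a reading aid, not code from the PR):

    PREFILL  -> use_micro_sdpa = true,  use_gqa_kernel = false
    MIXED    -> use_micro_sdpa = true,  use_gqa_kernel = false
    GENERATE -> use_micro_sdpa = false, use_gqa_kernel = can_use_gqa_kernel(...)

The previous sliding-window guard is dropped, so micro SDPA now covers mixed batches as well as pure prefill.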
@@ -1274,9 +1272,14 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         if (rt_params->use_gqa_kernel) {
             res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token : pa_gqa_single_token)};
         } else {
-            res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token : pa_single_token)};
+#ifdef ENABLE_ONEDNN_FOR_GPU
+            if (multi_tokens_mode && rt_params->use_micro_sdpa)
+                res_event = {execute_stage(res_event, instance, pa_sdpa_micro_mixed)};
+            else
+#endif
+                res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token : pa_single_token)};
         }
-        if (num_of_partitions > 1) {
+        if (num_of_partitions > 1 && !rt_params->use_micro_sdpa) {
             res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token_finalization : pa_single_token_finalization)};
         }
     }
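In the execution path, a MIXED batch with use_micro_sdpa set therefore runs the new pa_sdpa_micro_mixed stage in place of pa_multi_token, and the partitioned finalization stage is skipped whenever micro SDPA is used, presumably because the micro kernel writes the final output directly rather than per-partition partial results that would need a finalization pass.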
@@ -1364,11 +1367,9 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             const auto max_context_len = get_max_context_len(params);
             num_of_partitions = ceil_div(max_context_len, partition_size);
         }
-        bool can_use_micro_sdpa = stage == PagedAttentionStage::PREFILL;
+        bool can_use_micro_sdpa = false;
 #ifdef ENABLE_ONEDNN_FOR_GPU
-        can_use_micro_sdpa &= has_stage(pa_sdpa_micro);
-#else
-        can_use_micro_sdpa = false;
+        can_use_micro_sdpa = has_stage(pa_sdpa_micro) && stage != PagedAttentionStage::GENERATE;
 #endif
         GPU_DEBUG_TRACE_DETAIL << "get_internal_buffer_descs: stage = " << static_cast<size_t>(stage) << std::endl;
         int64_t paged_attention_aligned_seq_len = -1;
@@ -1447,13 +1448,13 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         }
 
         const auto multi_tokens_mode = stage == PagedAttentionStage::MIXED;
-        if (multi_tokens_mode) {
+        if (multi_tokens_mode && !can_use_micro_sdpa) {
             internal_buffers.emplace_back(total_tokens, softmax_accumulator_type, lockable); // 9
         }
 
 #ifdef ENABLE_ONEDNN_FOR_GPU
         if (can_use_micro_sdpa) {
-            const auto wg_tile_q = get_micro_tile_qsize(pa_sdpa_micro->kd);
+            const auto wg_tile_q = 8;  // Set to the minimum query block size so the buffer can be shared between the sdpa_micro prefill and mixed kernels.
             const auto target_seq_len = std::max(paged_attention_aligned_seq_len, static_cast<int64_t>(1));
             const auto indexes_buf_size = ceil_div(target_seq_len, wg_tile_q) * 2;
             internal_buffers.emplace_back(indexes_buf_size * 4, indexes_dt, lockable);
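As a worked example of the sizing above (the sequence length is illustrative, and indexes_dt is assumed to be a 4-byte type): with paged_attention_aligned_seq_len = 1024 and the fixed wg_tile_q = 8,

    indexes_buf_size = ceil_div(1024, 8) * 2 = 128 * 2 = 256

so the buffer is allocated for 256 entries, apparently two values (a block start and a GWS-mapping value) per query tile, matching the * 2. Pinning wg_tile_q at 8, the minimum tile size, yields the largest tile count either kernel variant can produce, so the same buffer is safe for whichever of the prefill or mixed micro kernels runs, at the cost of slight over-allocation when the actual tile is larger.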
@@ -1552,7 +1553,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;
         std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> micro_sdpa_block_starts_and_gws_mapping_lock = nullptr;
 
-        if (stage == PagedAttentionStage::MIXED) {
+        if (stage == PagedAttentionStage::MIXED && !use_micro_sdpa) {
             size_t sequential_gws_subseq_mapping_idx = 6;
             if (has_score_aggregation) {
                 sequential_gws_subseq_mapping_idx = 9;
@@ -1567,7 +1568,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
         }
 
-        if (stage == PagedAttentionStage::PREFILL && use_micro_sdpa) {
+        if (use_micro_sdpa) {
             const auto memory_idx = 3;  // intermediate_idx for micro kernel
             auto memory = intermediates_memories[memory_idx];
             micro_sdpa_block_starts_and_gws_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(memory, stream));
@@ -1619,7 +1620,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             }
         }
 
-        if (stage == PagedAttentionStage::MIXED) {
+        if (stage == PagedAttentionStage::MIXED && !use_micro_sdpa) {
             for (int32_t idx = seq_start; idx < seq_end; idx++) {
                 sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
             }