@@ -1087,6 +1087,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
     Stage::Ptr pa_scores_calc = make_stage<PagedAttentionGeneratorScoresCalculation>();
 #ifdef ENABLE_ONEDNN_FOR_GPU
     Stage::Ptr pa_sdpa_micro = make_stage<SDPAMicroGenerator>(true);
+    Stage::Ptr pa_sdpa_micro_mixed = make_stage<SDPAMicroGenerator>(false);
 #endif
 
     PagedAttentionOptImpl() : SDPAImplBase(PagedAttentionOpt::get_type_info_static()) {}
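Note on the new member: the SDPAMicroGenerator flag is undocumented at the call site; judging from how pa_sdpa_micro and pa_sdpa_micro_mixed are used later in this diff, true appears to select the prefill specialization and false the mixed one. A hypothetical readability tweak, not part of the patch, that names the flag:

    // Hypothetical named constants, assuming the bool means "prefill variant".
    constexpr bool kPrefillVariant = true;
    constexpr bool kMixedVariant   = false;
    // Stage::Ptr pa_sdpa_micro       = make_stage<SDPAMicroGenerator>(kPrefillVariant);
    // Stage::Ptr pa_sdpa_micro_mixed = make_stage<SDPAMicroGenerator>(kMixedVariant);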
@@ -1099,6 +1100,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         const bool use_micro_sdpa = supports_micro_sdpa(params);
         if (use_micro_sdpa) {
             add_stage(pa_sdpa_micro, params);
+            add_stage(pa_sdpa_micro_mixed, params);
         }
 #endif
 
@@ -1169,8 +1171,9 @@ class PagedAttentionOptImpl : public SDPAImplBase {
 
     size_t get_query_block_size(const PagedAttentionStage& stage, const bool use_micro_sdpa) const {
         const auto default_block_size = 16;
-        if (use_micro_sdpa && stage == PagedAttentionStage::PREFILL)
-            return get_micro_tile_qsize(pa_sdpa_micro->kd);
+        if (use_micro_sdpa) {
+            return (stage == PagedAttentionStage::PREFILL) ? get_micro_tile_qsize(pa_sdpa_micro->kd) : get_micro_tile_qsize(pa_sdpa_micro_mixed->kd);
+        }
         return default_block_size;
     }
 #else
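The rewritten get_query_block_size now pulls the tile size from whichever micro kernel will actually run instead of assuming PREFILL. A self-contained sketch of the selection shape; the tile values below are placeholders, since the real ones come from get_micro_tile_qsize on the kernel descriptors:

    #include <cstddef>

    enum class PaStage { PREFILL, MIXED, GENERATE };

    std::size_t query_block_size_for(PaStage stage, bool use_micro_sdpa) {
        constexpr std::size_t default_block_size = 16;  // matches the code above
        constexpr std::size_t prefill_tile_q = 16;  // placeholder for get_micro_tile_qsize(pa_sdpa_micro->kd)
        constexpr std::size_t mixed_tile_q = 8;     // placeholder for get_micro_tile_qsize(pa_sdpa_micro_mixed->kd)
        if (use_micro_sdpa)
            return stage == PaStage::PREFILL ? prefill_tile_q : mixed_tile_q;
        return default_block_size;
    }

The long ternary in the patch is functionally fine, though splitting it across two lines would read more easily.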
@@ -1213,22 +1216,17 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             rt_params->paged_attention_snap_kv_tokens = 0;
         }
 
-        if (rt_params->stage == PagedAttentionStage::PREFILL) {
 #ifdef ENABLE_ONEDNN_FOR_GPU
-            // Determine if sdpa_micro can be used based on sliding_window and aliged_seq_len
-            bool support_sliding_window =
-                desc->sliding_window == 0 || (desc->sliding_window > 0 && rt_params->paged_attention_aligned_seq_len < desc->sliding_window);
-            rt_params->use_micro_sdpa = supports_micro_sdpa(params) && support_sliding_window;
+        rt_params->use_micro_sdpa = supports_micro_sdpa(params) && rt_params->stage != PagedAttentionStage::GENERATE;
 #else
-            rt_params->use_micro_sdpa = false;
+        rt_params->use_micro_sdpa = false;
 #endif
-            rt_params->query_block_size = get_query_block_size(rt_params->stage, rt_params->use_micro_sdpa);
-        } else {
-            rt_params->use_micro_sdpa = false;
-        }
+        rt_params->query_block_size = get_query_block_size(rt_params->stage, rt_params->use_micro_sdpa);
 
-        if (rt_params->stage == PagedAttentionStage::GENERATE) {
+        if (rt_params->stage == PagedAttentionStage::GENERATE && !rt_params->use_micro_sdpa) {
             rt_params->use_gqa_kernel = can_use_gqa_kernel(params, PagedAttentionStage::GENERATE, rt_params->max_context_len);
+        } else {
+            rt_params->use_gqa_kernel = false;
         }
         return;
     }
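The control flow above is easier to audit as a table. A condensed, compilable restatement of the post-patch selection rules; the enum and the two predicates below are stand-ins for the production types, not the actual API:

    enum class PaStage { PREFILL, MIXED, GENERATE };

    // PREFILL  -> micro SDPA when supported
    // MIXED    -> micro SDPA when supported (new in this patch)
    // GENERATE -> never micro SDPA; GQA kernel when its own check passes
    bool micro_sdpa_enabled(PaStage stage, bool micro_supported) {
    #ifdef ENABLE_ONEDNN_FOR_GPU
        return micro_supported && stage != PaStage::GENERATE;
    #else
        (void)stage;
        (void)micro_supported;
        return false;
    #endif
    }

    bool gqa_enabled(PaStage stage, bool use_micro_sdpa, bool gqa_supported) {
        // Mirrors the diff: GQA stays GENERATE-only and explicitly yields to
        // micro SDPA. The !use_micro_sdpa term is redundant today (micro SDPA
        // is never enabled for GENERATE), but it keeps the invariant local.
        return stage == PaStage::GENERATE && !use_micro_sdpa && gqa_supported;
    }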
@@ -1274,9 +1272,14 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             if (rt_params->use_gqa_kernel) {
                 res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token : pa_gqa_single_token)};
             } else {
-                res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token : pa_single_token)};
+#ifdef ENABLE_ONEDNN_FOR_GPU
+                if (multi_tokens_mode && rt_params->use_micro_sdpa)
+                    res_event = {execute_stage(res_event, instance, pa_sdpa_micro_mixed)};
+                else
+#endif
+                    res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token : pa_single_token)};
             }
-            if (num_of_partitions > 1) {
+            if (num_of_partitions > 1 && !rt_params->use_micro_sdpa) {
                 res_event = {execute_stage(res_event, instance, multi_tokens_mode ? pa_multi_token_finalization : pa_single_token_finalization)};
             }
         }
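The adjusted finalization guard is the subtle part of this hunk: my reading is that the micro kernel produces the final attention output in a single pass, so the per-partition reduction would be redundant for it (the patch does not state this; a brief code comment would help). A toy restatement of the dispatch shape, with run() standing in for execute_stage and the event plumbing omitted:

    enum class Kernel { MultiToken, SingleToken, MicroMixed, Finalization };
    void run(Kernel) { /* enqueue the corresponding stage */ }

    void dispatch(bool multi_tokens_mode, bool use_micro_sdpa, int num_of_partitions) {
        if (multi_tokens_mode && use_micro_sdpa) {
            run(Kernel::MicroMixed);  // one pass, no per-partition partials
        } else {
            run(multi_tokens_mode ? Kernel::MultiToken : Kernel::SingleToken);
        }
        // Finalization reduces per-partition partial results, which the micro
        // path never produces, hence the new !use_micro_sdpa condition.
        if (num_of_partitions > 1 && !use_micro_sdpa) {
            run(Kernel::Finalization);
        }
    }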
@@ -1364,11 +1367,9 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             const auto max_context_len = get_max_context_len(params);
             num_of_partitions = ceil_div(max_context_len, partition_size);
         }
-        bool can_use_micro_sdpa = stage == PagedAttentionStage::PREFILL;
+        bool can_use_micro_sdpa = false;
 #ifdef ENABLE_ONEDNN_FOR_GPU
-        can_use_micro_sdpa &= has_stage(pa_sdpa_micro);
-#else
-        can_use_micro_sdpa = false;
+        can_use_micro_sdpa = has_stage(pa_sdpa_micro) && stage != PagedAttentionStage::GENERATE;
 #endif
         GPU_DEBUG_TRACE_DETAIL << "get_internal_buffer_descs: stage = " << static_cast<size_t>(stage) << std::endl;
         int64_t paged_attention_aligned_seq_len = -1;
@@ -1447,13 +1448,13 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         }
 
         const auto multi_tokens_mode = stage == PagedAttentionStage::MIXED;
-        if (multi_tokens_mode) {
+        if (multi_tokens_mode && !can_use_micro_sdpa) {
             internal_buffers.emplace_back(total_tokens, softmax_accumulator_type, lockable);  // 9
         }
 
 #ifdef ENABLE_ONEDNN_FOR_GPU
         if (can_use_micro_sdpa) {
-            const auto wg_tile_q = get_micro_tile_qsize(pa_sdpa_micro->kd);
+            const auto wg_tile_q = 8;  // Minimum query-block tile size, so the buffer can be shared by the sdpa_micro prefill and mixed kernels.
             const auto target_seq_len = std::max(paged_attention_aligned_seq_len, static_cast<int64_t>(1));
             const auto indexes_buf_size = ceil_div(target_seq_len, wg_tile_q) * 2;
             internal_buffers.emplace_back(indexes_buf_size * 4, indexes_dt, lockable);
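Hard-coding wg_tile_q to 8 sizes the index buffer for the smallest tile either micro kernel might use; a smaller divisor yields more blocks, so this over-allocates safely when the actual tile is larger. A worked example of the sizing arithmetic, with an illustrative aligned sequence length:

    #include <cstdint>

    constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

    // Illustrative: aligned sequence length 1024, wg_tile_q fixed at 8.
    constexpr int64_t target_seq_len = 1024;
    constexpr int64_t wg_tile_q = 8;
    constexpr int64_t indexes_buf_size = ceil_div(target_seq_len, wg_tile_q) * 2;  // 128 blocks * 2 = 256
    static_assert(indexes_buf_size == 256);
    // emplace_back then scales by 4, i.e. 1024 units in this example; whether
    // those are elements or bytes depends on the buffer descriptor convention.
    static_assert(indexes_buf_size * 4 == 1024);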
@@ -1552,7 +1553,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;
         std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> micro_sdpa_block_starts_and_gws_mapping_lock = nullptr;
 
-        if (stage == PagedAttentionStage::MIXED) {
+        if (stage == PagedAttentionStage::MIXED && !use_micro_sdpa) {
             size_t sequential_gws_subseq_mapping_idx = 6;
             if (has_score_aggregation) {
                 sequential_gws_subseq_mapping_idx = 9;
@@ -1567,7 +1568,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
         }
 
-        if (stage == PagedAttentionStage::PREFILL && use_micro_sdpa) {
+        if (use_micro_sdpa) {
             const auto memory_idx = 3;  // intermediate_idx for micro kernel
             auto memory = intermediates_memories[memory_idx];
             micro_sdpa_block_starts_and_gws_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(memory, stream));
@@ -1619,7 +1620,7 @@ class PagedAttentionOptImpl : public SDPAImplBase {
                 }
             }
 
-            if (stage == PagedAttentionStage::MIXED) {
+            if (stage == PagedAttentionStage::MIXED && !use_micro_sdpa) {
                 for (int32_t idx = seq_start; idx < seq_end; idx++) {
                     sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
                 }