@@ -429,8 +429,10 @@ Arguments PagedAttentionGeneratorMultiToken::get_arguments_desc(const kernel_imp
429429 args.push_back ({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SUBSEQUENCE_BEGINS}); // subsequence_begins
430430
431431 const size_t block_size = get_xattn_block_size (params);
432- if (block_size > 1 )
432+ if (block_size > 1 ) {
433433 args.push_back ({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4 }); // sparse_block_mask
434+ args.push_back ({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5 }); // sparse_block_mask_wg
435+ }
434436
435437 args.push_back ({ArgumentDescriptor::Types::OUTPUT, 0 });
436438
@@ -944,4 +946,74 @@ DispatchDataFunc XAttentionEstimateFindBlock::get_dispatch_data_func() const {
944946 }};
945947}
946948
949+ // -----------------------------------------------------------------------------------------------------------------
950+ // XAttention Estimate post_proc generator
951+ // -----------------------------------------------------------------------------------------------------------------
952+ JitConstants XAttentionEstimatePostProc::get_jit_constants (const kernel_impl_params& params) const {
953+ auto jit = XAttentionEstimateGeneratorBase::get_jit_constants (params);
954+
955+ jit.make (" MERGED_Q_NUM" , 2 ); // TODO
956+
957+ return jit;
958+ }
959+
960+ Arguments XAttentionEstimatePostProc::get_arguments_desc (const kernel_impl_params& params) const {
961+ Arguments args;
962+
963+ // inputs
964+ args.push_back ({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4 }); // block_mask
965+
966+ // outputs
967+ args.push_back ({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5 }); // block_mask_merged
968+
969+ // scalar
970+ args.push_back ({ArgumentDescriptor::Types::SCALAR, 0 }); // q_stride_pad
971+ args.push_back ({ArgumentDescriptor::Types::SCALAR, 1 }); // q_block_pad
972+ args.push_back ({ArgumentDescriptor::Types::SCALAR, 2 }); // k_block_pad
973+
974+ return args;
975+ }
976+
977+ DispatchDataFunc XAttentionEstimatePostProc::get_dispatch_data_func () const {
978+ return DispatchDataFunc{[&](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) {
979+ assert (!params.is_dynamic ());
980+ auto & wgs = kd.params .workGroups ;
981+
982+ const auto desc = params.typed_desc <paged_attention>();
983+
984+ assert (rt_params != nullptr );
985+
986+ const size_t block_size = get_xattn_block_size (params);
987+ const size_t heads_num = desc->heads_num ;
988+
989+ auto out_shape = params.output_layouts [0 ].get_shape ();
990+ const size_t kv_len = get_max_context_len (params) / STRIDE * STRIDE;
991+ const size_t q_len = out_shape[0 ];
992+ const uint32_t M = static_cast <uint32_t >(q_len / STRIDE); // # will slient drop the tails which is less than `stride`
993+ const uint32_t N = static_cast <uint32_t >(kv_len / STRIDE);
994+ const size_t q_stride_pad = round_up_to (M, BLOCK_WG_M);
995+ const size_t N_kq_groups = ceil_div (N, BLOCK_WG_N);
996+
997+ const uint32_t sum_per_token_in_block = static_cast <uint32_t >(block_size / STRIDE);
998+ const uint32_t k_block_in_group = static_cast <uint32_t >(BLOCK_WG_N / sum_per_token_in_block);
999+ const uint32_t k_block_pad = k_block_in_group * N_kq_groups;
1000+ const uint32_t q_block_pad = ceil_div (q_len, block_size);
1001+
1002+ const uint32_t MERGED_Q_NUM = 2 ; // TODO
1003+ const uint32_t q_block_pad_merged = ceil_div (q_block_pad, MERGED_Q_NUM);
1004+
1005+ wgs.global = {q_block_pad_merged, heads_num, 1 };
1006+ wgs.local = {1 , 1 , 1 };
1007+
1008+ auto & scalars = kd.params .scalars ;
1009+ std::vector<size_t > scaler_value = {q_stride_pad, q_block_pad, k_block_pad};
1010+ scalars.resize (scaler_value.size ());
1011+
1012+ for (size_t i = 0 ; i < scaler_value.size (); ++i) {
1013+ scalars[i].t = ScalarDescriptor::Types::UINT32;
1014+ scalars[i].v .u32 = static_cast <uint32_t >(scaler_value[i]);
1015+ }
1016+ }};
1017+ }
1018+
9471019} // namespace ov::intel_gpu::cm
0 commit comments