@@ -156,20 +156,6 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
         const std::string not_divisible_k = "(" + leftover_k + "!=0)";
         const std::string full_iteration_k = "(" + k_size + "/" + std::to_string(tuning_data.tile_k_size) + ")";

-        bool tile_k_may_have_leftover = false;
-        if (k_size.find("shape_info") == std::string::npos) {
-            tile_k_may_have_leftover = ((std::stoi(k_size) % tuning_data.tile_k_size) != 0);
-        } else {
-            tile_k_may_have_leftover = true;
-        }
-
-        bool tile_n_may_have_leftover = false;
-        if (n_size.find("shape_info") == std::string::npos) {
-            tile_n_may_have_leftover = ((std::stoi(n_size) % tuning_data.tile_n_size) != 0);
-        } else {
-            tile_n_may_have_leftover = true;
-        }
-
         jit.AddConstants({
             MakeJitConstant("M", m_size),
             MakeJitConstant("K", k_size),
@@ -182,10 +168,8 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
             MakeJitConstant("TILE_N", tuning_data.tile_n_size),
             MakeJitConstant("K_FULL_ITERATIONS", full_iteration_k),
             MakeJitConstant("TILE_M_NOT_DIVISIBLE", not_divisible_m),
-            MakeJitConstant("TILE_K_NOT_DIVISIBLE", tile_k_may_have_leftover),
-            MakeJitConstant("TILE_K_NOT_DIVISIBLE_CALC", not_divisible_k),
-            MakeJitConstant("TILE_N_NOT_DIVISIBLE", tile_n_may_have_leftover),
-            MakeJitConstant("TILE_N_NOT_DIVISIBLE_CALC", not_divisible_n),
+            MakeJitConstant("TILE_K_NOT_DIVISIBLE", not_divisible_k),
+            MakeJitConstant("TILE_N_NOT_DIVISIBLE", not_divisible_n),
             MakeJitConstant("TILE_M_LEFTOVER", leftover_m),
             MakeJitConstant("TILE_K_LEFTOVER", leftover_k),
             MakeJitConstant("TILE_N_LEFTOVER", leftover_n),
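With this hunk, TILE_K_NOT_DIVISIBLE and TILE_N_NOT_DIVISIBLE receive the string expressions not_divisible_k / not_divisible_n instead of host-computed booleans, so for dynamic shapes the check is evaluated from shape_info inside the generated kernel and the separate *_CALC constants are no longer needed. A minimal sketch of what that amounts to; the exact form of leftover_k is built earlier in GetJitConstants and is only assumed here:

// Sketch only, not part of the patch. Assumed shape of leftover_k:
const std::string leftover_k = "(" + k_size + " % " + std::to_string(tuning_data.tile_k_size) + ")";
const std::string not_divisible_k = "(" + leftover_k + "!=0)";
// MakeJitConstant("TILE_K_NOT_DIVISIBLE", not_divisible_k) then emits roughly
//   #define TILE_K_NOT_DIVISIBLE ((K % TILE_K) != 0)
// into the OpenCL source, rather than a 0/1 literal decided on the host.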
@@ -356,7 +340,54 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
 }

 KernelsData GemmKernelTiledOpt::GetKernelsData(const Params& params) const {
-    return GetCommonKernelsData(params);
+    if (!Validate(params)) {
+        return KernelsData();
+    }
+
+    const auto& prim_params = static_cast<const gemm_params&>(params);
+    size_t num_kernels = params.is_shape_agnostic ? 4 : 1;
+    auto dispatchData = SetDefault(prim_params);
+    KernelData k_data = KernelData::Default<gemm_params>(params, num_kernels);
+    GetUpdateDispatchDataFunc(k_data);
+    auto cldnn_jit = GetJitConstants(prim_params);
+    for (size_t i = 0; i < num_kernels; i++) {
+        if (params.is_shape_agnostic) {
+            cldnn_jit.RemoveConstant("TILE_K_NOT_DIVISIBLE");
+            cldnn_jit.RemoveConstant("TILE_N_NOT_DIVISIBLE");
+            if (i == 0) {
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_K_NOT_DIVISIBLE", "0"));
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_N_NOT_DIVISIBLE", "0"));
+            } else if (i == 1) {
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_K_NOT_DIVISIBLE", "0"));
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_N_NOT_DIVISIBLE", "1"));
+            } else if (i == 2) {
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_K_NOT_DIVISIBLE", "1"));
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_N_NOT_DIVISIBLE", "0"));
+            } else if (i == 3) {
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_K_NOT_DIVISIBLE", "1"));
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_N_NOT_DIVISIBLE", "1"));
+            }
+        }
+        auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, params, i);
+        auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
+
+        auto& kernel = k_data.kernels[i];
+        FillCLKernelData(kernel,
+                         dispatchData,
+                         params.engineInfo,
+                         kernelName,
+                         jit,
+                         entry_point,
+                         EXE_MODE_DEFAULT,
+                         false,
+                         false,
+                         (uint32_t)prim_params.inputs.size(),
+                         GetFusedPrimitiveInputsCount(params),
+                         1,
+                         prim_params.is_shape_agnostic);
+    }
+
+    return {k_data};
 }

 KernelsPriority GemmKernelTiledOpt::GetKernelsPriority(const Params& params) const {
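For shape-agnostic builds, the loop above compiles four specializations of the tiled kernel, one per (TILE_K_NOT_DIVISIBLE, TILE_N_NOT_DIVISIBLE) combination, overriding the runtime expressions with the literals 0/1. The variant index is effectively a two-bit code; the helper below is hypothetical and only documents the ordering the patch relies on:

// Hypothetical helper, not in the patch; mirrors the i == 0..3 branches above.
// kernel 0: K divisible,    N divisible
// kernel 1: K divisible,    N has leftover
// kernel 2: K has leftover, N divisible
// kernel 3: K has leftover, N has leftover
inline size_t gemm_variant_index(bool k_has_leftover, bool n_has_leftover) {
    return (k_has_leftover ? 2u : 0u) + (n_has_leftover ? 1u : 0u);
}

GetUpdateDispatchDataFunc (added further down) computes execute_kernel_idx with exactly this ordering, so the two sites have to stay in sync.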
@@ -404,4 +435,64 @@ bool GemmKernelTiledOpt::Validate(const Params& params) const {

     return true;
 }
+
+void GemmKernelTiledOpt::GetUpdateDispatchDataFunc(KernelData& kd) const {
+    if (kd.kernels.size() == 1) {
+        Parent::GetUpdateDispatchDataFunc(kd);
+    } else {
+        kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
+            const auto& prim_params = static_cast<const gemm_params&>(params);
+
+            auto getTensorValue = [](const DataTensor& t, const int64_t dim_idx) -> size_t {
+                switch (dim_idx) {
+                    case 1:
+                        return t.Feature().v;
+                    case 2:
+                        return t.U().v;
+                    case 3:
+                        return t.V().v;
+                    case 4:
+                        return t.W().v;
+                    case 5:
+                        return t.Z().v;
+                    case 6:
+                        return t.Y().v;
+                    case 7:
+                        return t.X().v;
+                    default:
+                        return t.Batch().v;
+                }
+            };
+
+            GemmTuningData tuning_data = SetTuningParams(prim_params);
+            auto input0_dims = ConvTo8dims(prim_params.input0_order);
+            auto input1_dims = ConvTo8dims(prim_params.input1_order);
+            auto k_size = getTensorValue(prim_params.inputs[0], input0_dims[7]);
+            auto n_size = getTensorValue(prim_params.inputs[1], input1_dims[7]);
+            bool not_divisible_k = ((k_size % tuning_data.tile_k_size) != 0);
+            bool not_divisible_n = ((n_size % tuning_data.tile_n_size) != 0);
+            size_t execute_kernel_idx = 0;
+            if (not_divisible_k == false && not_divisible_n == false) {
+                execute_kernel_idx = 0;
+            } else if (not_divisible_k == false && not_divisible_n == true) {
+                execute_kernel_idx = 1;
+            } else if (not_divisible_k == true && not_divisible_n == false) {
+                execute_kernel_idx = 2;
+            } else if (not_divisible_k == true && not_divisible_n == true) {
+                execute_kernel_idx = 3;
+            }
+
+            auto dispatchData = SetDefault(prim_params);
+            for (size_t i = 0; i < kd.kernels.size(); i++) {
+                kd.kernels[i].params.workGroups.global = dispatchData.gws;
+                kd.kernels[i].params.workGroups.local = dispatchData.lws;
+                if (execute_kernel_idx == i) {
+                    kd.kernels[i].skip_execution = KernelData::SkipKernelExecution(prim_params);
+                } else {
+                    kd.kernels[i].skip_execution = true;
+                }
+            }
+        };
+    }
+}
 } // namespace kernel_selector