
Commit 34825fe

[GPU] Multi kernel gemm for TILE_K_NOT_DIVISIBLE and TILE_N_NOT_DIVISIBLE (#24007)
### Details:
- In the shape-agnostic gemm kernel, `TILE_K_NOT_DIVISIBLE` and `TILE_N_NOT_DIVISIBLE` are expressed as conditional branches rather than constant values, which can degrade performance.
- This PR changes the gemm implementation to build four OpenCL kernels covering the combinations of `TILE_K_NOT_DIVISIBLE` and `TILE_N_NOT_DIVISIBLE` set to 0 and 1. Just before enqueuing a gemm kernel, the host checks the actual values of `TILE_K_NOT_DIVISIBLE` and `TILE_N_NOT_DIVISIBLE` and chooses the matching kernel.

### Tickets:
- 134699
1 parent 4b9f92a commit 34825fe
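The host-side choice described above is just a two-bit encoding of the divisibility flags. A minimal sketch of that mapping in plain C++ (illustrative only: `variant_index` is not part of the OpenVINO sources; the real selection is the `execute_kernel_idx` chain in the diff below):

```cpp
#include <cstddef>

// Variant i is compiled with TILE_K_NOT_DIVISIBLE = i >> 1 and
// TILE_N_NOT_DIVISIBLE = i & 1, so the host can pick the kernel to
// enqueue by packing the two runtime flags into a two-bit index.
constexpr std::size_t variant_index(bool k_not_divisible, bool n_not_divisible) {
    return (k_not_divisible ? 2u : 0u) | (n_not_divisible ? 1u : 0u);
}

static_assert(variant_index(false, false) == 0, "K and N divisible");
static_assert(variant_index(false, true)  == 1, "only N has leftover");
static_assert(variant_index(true,  false) == 2, "only K has leftover");
static_assert(variant_index(true,  true)  == 3, "both have leftovers");
```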

File tree

3 files changed: +121 -47 lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl

Lines changed: 10 additions & 28 deletions
```diff
@@ -270,15 +270,11 @@ KERNEL(gemm_tiled_opt)(
             b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
 #else // B_VEC_SIZE == 1
 #if TILE_N_NOT_DIVISIBLE
-            if (TILE_N_NOT_DIVISIBLE_CALC) {
-                unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
-                    b_tile[b_load_id][b_elem] = b_ptr[sglid + SIMD_WIDTH * b_elem];
-                }
-            } else {
-                b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
-            }
+            unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
+                b_tile[b_load_id][b_elem] = b_ptr[sglid + SIMD_WIDTH * b_elem];
+            }
 #else // TILE_N_NOT_DIVISIBLE
-                b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
+            b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
 #endif // TILE_N_NOT_DIVISIBLE
 #endif // B_VEC_SIZE == 1
             b_ptr += input1_offset;
@@ -381,11 +377,7 @@ KERNEL(gemm_tiled_opt)(
 
         // Loading A tile and tile C calculation
 #if IS_DYNAMIC && !INDIRECT_INPUT0 && !HAS_DYNAMIC_K_PADDING && TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST
-#if TILE_K_NOT_DIVISIBLE
-        A_FLOATN a_read = TILE_K_NOT_DIVISIBLE_CALC ? a_ptr[sglid] : BLOCK_READ_A(a_ptr, 0);
-#else
-        A_FLOATN a_read = BLOCK_READ_A(a_ptr, 0);
-#endif
+        A_FLOATN a_read = TILE_K_NOT_DIVISIBLE ? a_ptr[sglid] : BLOCK_READ_A(a_ptr, 0);
 #endif
         unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
 #if TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST
@@ -431,11 +423,7 @@ KERNEL(gemm_tiled_opt)(
             }
 #if IS_DYNAMIC && !INDIRECT_INPUT0 && !HAS_DYNAMIC_K_PADDING
             // Read A for next dot_id
-#if TILE_K_NOT_DIVISIBLE
-            a_read = (dot_id + 1 < tile_m_iterations) ? TILE_K_NOT_DIVISIBLE_CALC ? a_ptr[sglid] : BLOCK_READ_A(a_ptr, 0) : 0;
-#else
-            a_read = (dot_id + 1 < tile_m_iterations) ? BLOCK_READ_A(a_ptr, 0) : 0;
-#endif
+            a_read = (dot_id + 1 < tile_m_iterations) ? TILE_K_NOT_DIVISIBLE ? a_ptr[sglid] : BLOCK_READ_A(a_ptr, 0) : 0;
 #endif
 #elif TRANSPOSE_INPUT0 == TRANSPOSE_OTHER // TRANSPOSE_INPUT0
 #if INDIRECT_INPUT0
@@ -482,9 +470,8 @@ KERNEL(gemm_tiled_opt)(
     // Full tile calculation end
 
     // Handle leftovers for K
-#if TILE_K_NOT_DIVISIBLE
-#if IS_DYNAMIC
-    if (TILE_K_NOT_DIVISIBLE_CALC) {
+#if IS_DYNAMIC
+    if (TILE_K_NOT_DIVISIBLE) {
         // Loading leftovers of the matrix B
 #if TRANSPOSE_INPUT1 != TRANSPOSE_Y_LAST
         B_FLOATN b_tile[TILE_K];
@@ -520,15 +507,11 @@ KERNEL(gemm_tiled_opt)(
             b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
 #else // B_VEC_SIZE == 1
 #if TILE_N_NOT_DIVISIBLE
-            if (TILE_N_NOT_DIVISIBLE_CALC) {
             unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
                 b_tile[b_load_id][b_elem] = b_ptr[sglid + SIMD_WIDTH * b_elem];
             }
-            } else {
-                b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
-            }
 #else
-                b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
+            b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
 #endif // TILE_N_NOT_DIVISIBLE
 #endif // B_VEC_SIZE == 1
             b_ptr += input1_offset;
@@ -698,8 +681,7 @@ KERNEL(gemm_tiled_opt)(
                 c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read, simd_id)), b_tile[simd_id], c_tile[dot_id]);
             }
         } // Loading leftovers of the matrix A and tile C calculation end
-#endif // IS_DYNAMIC
-#endif // TILE_K_NOT_DIVISIBLE
+#endif // IS_DYNAMIC
 
 #if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
 #if IS_DYNAMIC
```
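With this change, `TILE_K_NOT_DIVISIBLE` and `TILE_N_NOT_DIVISIBLE` expand to a literal `0` or `1` in each shape-agnostic variant, so the compiler can fold the conditions above and discard the unused load path. A stand-alone C++ illustration of the same folding, assuming the JIT injects the macro per variant (`scalar_load` and `block_load` are hypothetical stand-ins, not kernel code):

```cpp
#include <cstdio>

// In the shape-agnostic build, the JIT injects a literal per variant,
// e.g. -DTILE_K_NOT_DIVISIBLE=0. A constant condition lets the compiler
// delete the dead branch; the old runtime *_CALC expression could not.
#ifndef TILE_K_NOT_DIVISIBLE
#define TILE_K_NOT_DIVISIBLE 0
#endif

// Stand-ins for the kernel's scalar leftover load vs. the block read.
static float scalar_load(const float* a, int i) { return a[i]; }
static float block_load(const float* a) { return a[0]; }

float read_a(const float* a_ptr, int sglid) {
    // With TILE_K_NOT_DIVISIBLE == 0 this compiles down to block_load alone.
    return TILE_K_NOT_DIVISIBLE ? scalar_load(a_ptr, sglid) : block_load(a_ptr);
}

int main() {
    const float a[4] = {1.f, 2.f, 3.f, 4.f};
    std::printf("%f\n", read_a(a, 2));
    return 0;
}
```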

src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp

Lines changed: 110 additions & 19 deletions
```diff
@@ -156,20 +156,6 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
         const std::string not_divisible_k = "(" + leftover_k + "!=0)";
         const std::string full_iteration_k = "(" + k_size + "/" + std::to_string(tuning_data.tile_k_size) + ")";
 
-        bool tile_k_may_have_leftover = false;
-        if (k_size.find("shape_info") == std::string::npos) {
-            tile_k_may_have_leftover = ((std::stoi(k_size) % tuning_data.tile_k_size) != 0);
-        } else {
-            tile_k_may_have_leftover = true;
-        }
-
-        bool tile_n_may_have_leftover = false;
-        if (n_size.find("shape_info") == std::string::npos) {
-            tile_n_may_have_leftover = ((std::stoi(n_size) % tuning_data.tile_n_size) != 0);
-        } else {
-            tile_n_may_have_leftover = true;
-        }
-
         jit.AddConstants({
             MakeJitConstant("M", m_size),
             MakeJitConstant("K", k_size),
@@ -182,10 +168,8 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
             MakeJitConstant("TILE_N", tuning_data.tile_n_size),
             MakeJitConstant("K_FULL_ITERATIONS", full_iteration_k),
             MakeJitConstant("TILE_M_NOT_DIVISIBLE", not_divisible_m),
-            MakeJitConstant("TILE_K_NOT_DIVISIBLE", tile_k_may_have_leftover),
-            MakeJitConstant("TILE_K_NOT_DIVISIBLE_CALC", not_divisible_k),
-            MakeJitConstant("TILE_N_NOT_DIVISIBLE", tile_n_may_have_leftover),
-            MakeJitConstant("TILE_N_NOT_DIVISIBLE_CALC", not_divisible_n),
+            MakeJitConstant("TILE_K_NOT_DIVISIBLE", not_divisible_k),
+            MakeJitConstant("TILE_N_NOT_DIVISIBLE", not_divisible_n),
             MakeJitConstant("TILE_M_LEFTOVER", leftover_m),
             MakeJitConstant("TILE_K_LEFTOVER", leftover_k),
             MakeJitConstant("TILE_N_LEFTOVER", leftover_n),
@@ -356,7 +340,54 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
 }
 
 KernelsData GemmKernelTiledOpt::GetKernelsData(const Params& params) const {
-    return GetCommonKernelsData(params);
+    if (!Validate(params)) {
+        return KernelsData();
+    }
+
+    const auto& prim_params = static_cast<const gemm_params&>(params);
+    size_t num_kernels = params.is_shape_agnostic ? 4 : 1;
+    auto dispatchData = SetDefault(prim_params);
+    KernelData k_data = KernelData::Default<gemm_params>(params, num_kernels);
+    GetUpdateDispatchDataFunc(k_data);
+    auto cldnn_jit = GetJitConstants(prim_params);
+    for (size_t i = 0; i < num_kernels; i++) {
+        if (params.is_shape_agnostic) {
+            cldnn_jit.RemoveConstant("TILE_K_NOT_DIVISIBLE");
+            cldnn_jit.RemoveConstant("TILE_N_NOT_DIVISIBLE");
+            if (i == 0) {
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_K_NOT_DIVISIBLE", "0"));
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_N_NOT_DIVISIBLE", "0"));
+            } else if (i == 1) {
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_K_NOT_DIVISIBLE", "0"));
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_N_NOT_DIVISIBLE", "1"));
+            } else if (i == 2) {
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_K_NOT_DIVISIBLE", "1"));
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_N_NOT_DIVISIBLE", "0"));
+            } else if (i == 3) {
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_K_NOT_DIVISIBLE", "1"));
+                cldnn_jit.AddConstant(MakeJitConstant("TILE_N_NOT_DIVISIBLE", "1"));
+            }
+        }
+        auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, params, i);
+        auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
+
+        auto& kernel = k_data.kernels[i];
+        FillCLKernelData(kernel,
+                         dispatchData,
+                         params.engineInfo,
+                         kernelName,
+                         jit,
+                         entry_point,
+                         EXE_MODE_DEFAULT,
+                         false,
+                         false,
+                         (uint32_t)prim_params.inputs.size(),
+                         GetFusedPrimitiveInputsCount(params),
+                         1,
+                         prim_params.is_shape_agnostic);
+    }
+
+    return {k_data};
 }
 
 KernelsPriority GemmKernelTiledOpt::GetKernelsPriority(const Params& params) const {
@@ -404,4 +435,64 @@ bool GemmKernelTiledOpt::Validate(const Params& params) const {
 
     return true;
 }
+
+void GemmKernelTiledOpt::GetUpdateDispatchDataFunc(KernelData& kd) const {
+    if (kd.kernels.size() == 1) {
+        Parent::GetUpdateDispatchDataFunc(kd);
+    } else {
+        kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
+            const auto& prim_params = static_cast<const gemm_params&>(params);
+
+            auto getTensorValue = [](const DataTensor& t, const int64_t dim_idx) -> size_t {
+                switch (dim_idx) {
+                    case 1:
+                        return t.Feature().v;
+                    case 2:
+                        return t.U().v;
+                    case 3:
+                        return t.V().v;
+                    case 4:
+                        return t.W().v;
+                    case 5:
+                        return t.Z().v;
+                    case 6:
+                        return t.Y().v;
+                    case 7:
+                        return t.X().v;
+                    default:
+                        return t.Batch().v;
+                }
+            };
+
+            GemmTuningData tuning_data = SetTuningParams(prim_params);
+            auto input0_dims = ConvTo8dims(prim_params.input0_order);
+            auto input1_dims = ConvTo8dims(prim_params.input1_order);
+            auto k_size = getTensorValue(prim_params.inputs[0], input0_dims[7]);
+            auto n_size = getTensorValue(prim_params.inputs[1], input1_dims[7]);
+            bool not_divisible_k = ((k_size % tuning_data.tile_k_size) != 0);
+            bool not_divisible_n = ((n_size % tuning_data.tile_n_size) != 0);
+            size_t execute_kernel_idx = 0;
+            if (not_divisible_k == false && not_divisible_n == false) {
+                execute_kernel_idx = 0;
+            } else if (not_divisible_k == false && not_divisible_n == true) {
+                execute_kernel_idx = 1;
+            } else if (not_divisible_k == true && not_divisible_n == false) {
+                execute_kernel_idx = 2;
+            } else if (not_divisible_k == true && not_divisible_n == true) {
+                execute_kernel_idx = 3;
+            }
+
+            auto dispatchData = SetDefault(prim_params);
+            for (size_t i = 0; i < kd.kernels.size(); i++) {
+                kd.kernels[i].params.workGroups.global = dispatchData.gws;
+                kd.kernels[i].params.workGroups.local = dispatchData.lws;
+                if (execute_kernel_idx == i) {
+                    kd.kernels[i].skip_execution = KernelData::SkipKernelExecution(prim_params);
+                } else {
+                    kd.kernels[i].skip_execution = true;
+                }
+            }
+        };
+    }
+}
 } // namespace kernel_selector
```
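At runtime, the `update_dispatch_data_func` above recomputes K/N divisibility from the actual shapes and lets exactly one of the four kernels run. A hedged sketch of that selection pass with simplified types (`KernelStub` and `select_variant` are hypothetical, not the OpenVINO API):

```cpp
#include <cstddef>
#include <vector>

struct KernelStub {
    bool skip_execution = true;  // stand-in for the clKernelData flag
};

// Mirror of the selection above: compute the divisibility flags from the
// actual K/N sizes and allow only the matching variant to execute.
void select_variant(std::vector<KernelStub>& kernels,
                    std::size_t k, std::size_t n,
                    std::size_t tile_k, std::size_t tile_n) {
    const std::size_t idx = ((k % tile_k != 0) ? 2u : 0u)
                          | ((n % tile_n != 0) ? 1u : 0u);
    for (std::size_t i = 0; i < kernels.size(); ++i)
        kernels[i].skip_execution = (i != idx);  // enqueue only kernels[idx]
}
```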

src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -36,5 +36,6 @@ class GemmKernelTiledOpt : public GemmKernelBase {
     JitConstants GetJitConstants(const gemm_params& params) const override;
     GemmTuningData SetTuningParams(const gemm_params& params) const;
     DeviceFeaturesKey get_required_device_features_key(const Params& params) const override;
+    void GetUpdateDispatchDataFunc(KernelData& kd) const override;
 };
 } // namespace kernel_selector
```
