Skip to content

Commit 4b9f92a

Browse files
[GPU] Extend gemm to fuse unsqueeze layer (#23734)
### Details:
- Follow up on review comments from #23513
- Fuse the `unsqueeze` layer into the `gemm` layer for indirect gemm
  - before: [`kv_cache`] --> [`unsqueeze`] --> `gemm`
  - after: [`kv_cache`] --> `gemm`
- Simplify the fusion pass and logic, since `unsqueeze` is now fused into `gemm`

### Tickets:
- 136567

---------

Signed-off-by: Andrew Park <[email protected]>
1 parent 6e961cd commit 4b9f92a

File tree

14 files changed

+182
-430
lines changed

14 files changed

+182
-430
lines changed

src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,12 @@ class Gemm : public ov::op::v0::MatMul {
2626
const std::vector<int64_t>& order_c,
2727
const ov::element::Type output_type = ov::element::undefined);
2828

29-
Gemm(const ov::Output<Node>& A,
30-
const ov::Output<Node>& B,
31-
const std::vector<int32_t>& target_shape_a,
32-
const std::vector<int32_t>& target_shape_b,
33-
const std::vector<int64_t>& output_pattern_a,
34-
const std::vector<int64_t>& output_pattern_b,
35-
const std::vector<int64_t>& order_a,
36-
const std::vector<int64_t>& order_b,
37-
const std::vector<int64_t>& order_c,
38-
const ov::element::Type output_type = ov::element::undefined);
39-
4029
bool visit_attributes(ov::AttributeVisitor &visitor) override;
4130

4231
void validate_and_infer_types() override;
4332

4433
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
4534

46-
std::vector<int32_t> get_input0_broadcast_target_shape() const { return m_target_shape_a; }
47-
std::vector<int32_t> get_input1_broadcast_target_shape() const { return m_target_shape_b; }
48-
std::vector<int64_t> get_input0_reshape_pattern() const { return m_output_pattern_a; }
49-
std::vector<int64_t> get_input1_reshape_pattern() const { return m_output_pattern_b; }
5035
std::vector<int64_t> get_input0_transpose_order() const { return m_order_a; }
5136
std::vector<int64_t> get_input1_transpose_order() const { return m_order_b; }
5237
std::vector<int64_t> get_output_transpose_order() const { return m_order_c; }
@@ -59,10 +44,6 @@ class Gemm : public ov::op::v0::MatMul {
5944
}
6045

6146
protected:
62-
std::vector<int32_t> m_target_shape_a;
63-
std::vector<int32_t> m_target_shape_b;
64-
std::vector<int64_t> m_output_pattern_a;
65-
std::vector<int64_t> m_output_pattern_b;
6647
std::vector<int64_t> m_order_a;
6748
std::vector<int64_t> m_order_b;
6849
std::vector<int64_t> m_order_c;
@@ -71,10 +52,6 @@ class Gemm : public ov::op::v0::MatMul {
7152

7253
std::vector<ov::PartialShape> shape_infer(const Gemm* op,
7354
std::vector<ov::PartialShape> input_shapes,
74-
const std::vector<int32_t>& target_shape_a,
75-
const std::vector<int32_t>& target_shape_b,
76-
const std::vector<int64_t>& output_pattern_a,
77-
const std::vector<int64_t>& output_pattern_b,
7855
const std::vector<int64_t>& order_a,
7956
const std::vector<int64_t>& order_b,
8057
const std::vector<int64_t>& order_c);

src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ struct gemm : public primitive_base<gemm> {
5454
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
5555
transpose_input0(transpose_input0 ? 1 : 0),
5656
transpose_input1(transpose_input1 ? 1 : 0),
57-
input0_broadcast_target_shape({}),
58-
input1_broadcast_target_shape({}),
59-
input0_reshape_pattern({}),
60-
input1_reshape_pattern({}),
6157
alpha(alpha),
6258
beta(beta),
6359
input_rank(input_rank),
@@ -90,21 +86,13 @@ struct gemm : public primitive_base<gemm> {
9086
gemm(const primitive_id& id,
9187
const std::vector<input_info>& inputs,
9288
const data_types data_type,
93-
const std::vector<int32_t>& input0_broadcast_target_shape = {},
94-
const std::vector<int32_t>& input1_broadcast_target_shape = {},
95-
const std::vector<int64_t>& input0_reshape_pattern = {},
96-
const std::vector<int64_t>& input1_reshape_pattern = {},
9789
const std::vector<int64_t>& input0_transpose_order = {0, 1, 2, 3},
9890
const std::vector<int64_t>& input1_transpose_order = {0, 1, 2, 3},
9991
const std::vector<int64_t>& output_transpose_order = {},
10092
const float alpha = 1.0f,
10193
const float beta = 0.0f,
10294
const padding& output_padding = padding())
10395
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
104-
input0_broadcast_target_shape(input0_broadcast_target_shape),
105-
input1_broadcast_target_shape(input1_broadcast_target_shape),
106-
input0_reshape_pattern(input0_reshape_pattern),
107-
input1_reshape_pattern(input1_reshape_pattern),
10896
input0_transpose_order(input0_transpose_order),
10997
input1_transpose_order(input1_transpose_order),
11098
output_transpose_order(output_transpose_order),
@@ -133,10 +121,6 @@ struct gemm : public primitive_base<gemm> {
133121
const float beta = 0.0f,
134122
const padding& output_padding = padding())
135123
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
136-
input0_broadcast_target_shape({}),
137-
input1_broadcast_target_shape({}),
138-
input0_reshape_pattern({}),
139-
input1_reshape_pattern({}),
140124
input0_transpose_order(input0_transpose_order),
141125
input1_transpose_order(input1_transpose_order),
142126
output_transpose_order(output_transpose_order),
@@ -159,14 +143,6 @@ struct gemm : public primitive_base<gemm> {
159143
uint32_t transpose_input0 = 0;
160144
/// @brief Flag for transposing second input matrix
161145
uint32_t transpose_input1 = 0;
162-
/// @brief broadcasted target shape of input 0
163-
std::vector<int32_t> input0_broadcast_target_shape;
164-
/// @brief broadcasted target shape of input 1
165-
std::vector<int32_t> input1_broadcast_target_shape;
166-
/// @brief reshaped output pattern of input 0
167-
std::vector<int64_t> input0_reshape_pattern;
168-
/// @brief reshaped output pattern of input 1
169-
std::vector<int64_t> input1_reshape_pattern;
170146
/// @brief order of input 0
171147
std::vector<int64_t> input0_transpose_order;
172148
/// @brief order of input 1
@@ -193,10 +169,6 @@ struct gemm : public primitive_base<gemm> {
193169
seed = hash_combine(seed, transpose_input1);
194170
seed = hash_combine(seed, indirect_a);
195171
seed = hash_combine(seed, indirect_b);
196-
seed = hash_range(seed, input0_broadcast_target_shape.begin(), input0_broadcast_target_shape.end());
197-
seed = hash_range(seed, input1_broadcast_target_shape.begin(), input1_broadcast_target_shape.end());
198-
seed = hash_range(seed, input0_reshape_pattern.begin(), input0_reshape_pattern.end());
199-
seed = hash_range(seed, input1_reshape_pattern.begin(), input1_reshape_pattern.end());
200172
seed = hash_range(seed, input0_transpose_order.begin(), input0_transpose_order.end());
201173
seed = hash_range(seed, input1_transpose_order.begin(), input1_transpose_order.end());
202174
seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
@@ -225,10 +197,6 @@ struct gemm : public primitive_base<gemm> {
225197
primitive_base<gemm>::save(ob);
226198
ob << transpose_input0;
227199
ob << transpose_input1;
228-
ob << input0_broadcast_target_shape;
229-
ob << input1_broadcast_target_shape;
230-
ob << input0_reshape_pattern;
231-
ob << input1_reshape_pattern;
232200
ob << input0_transpose_order;
233201
ob << input1_transpose_order;
234202
ob << output_transpose_order;
@@ -246,10 +214,6 @@ struct gemm : public primitive_base<gemm> {
246214
primitive_base<gemm>::load(ib);
247215
ib >> transpose_input0;
248216
ib >> transpose_input1;
249-
ib >> input0_broadcast_target_shape;
250-
ib >> input1_broadcast_target_shape;
251-
ib >> input0_reshape_pattern;
252-
ib >> input1_reshape_pattern;
253217
ib >> input0_transpose_order;
254218
ib >> input1_transpose_order;
255219
ib >> output_transpose_order;

src/plugins/intel_gpu/src/graph/gemm.cpp

Lines changed: 4 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,6 @@
1010

1111
#include "intel_gpu/op/gemm.hpp"
1212

13-
namespace {
14-
template <typename T, typename DT, typename = typename std::enable_if<std::is_convertible<DT, T>::value>::type>
15-
int find_index_from_vec(const std::vector<T>& vec, const DT value) {
16-
int idx = 0;
17-
for (auto v : vec) {
18-
if (v != static_cast<T>(value))
19-
break;
20-
idx += 1;
21-
}
22-
return idx;
23-
}
24-
} // namespace
2513
namespace cldnn {
2614
GPU_DEFINE_PRIMITIVE_TYPE_ID(gemm)
2715

@@ -139,10 +127,6 @@ std::vector<layout> gemm_inst::calc_output_layouts(gemm_node const& node, const
139127

140128
std::vector<ShapeType> output_shapes = ov::intel_gpu::op::shape_infer(&op,
141129
input_shapes,
142-
prim->input0_broadcast_target_shape,
143-
prim->input1_broadcast_target_shape,
144-
prim->input0_reshape_pattern,
145-
prim->input1_reshape_pattern,
146130
prim->input0_transpose_order,
147131
prim->input1_transpose_order,
148132
prim->output_transpose_order);
@@ -158,28 +142,6 @@ template std::vector<layout> gemm_inst::calc_output_layouts<ov::PartialShape>(ge
158142

159143
std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<const gemm> primitive,
160144
const std::vector<layout>& input_layouts) {
161-
auto get_reshaped_input_shape = [&](const ov::PartialShape& input_pshape,
162-
const std::vector<int32_t>& broadcast_target_shape,
163-
const std::vector<int64_t>& reshape_pattern) {
164-
ov::PartialShape reshaped_input_pshape;
165-
166-
if (broadcast_target_shape.size() > 0 && reshape_pattern.size() > 0) {
167-
std::vector<ov::Dimension> dims(input_pshape);
168-
int idx_recalc = find_index_from_vec(broadcast_target_shape, 1);
169-
int idx_target = find_index_from_vec(reshape_pattern, 0);
170-
if (dims[idx_recalc].is_static() && dims[idx_target].is_static()) {
171-
dims[idx_recalc] *= dims[idx_target];
172-
} else {
173-
dims[idx_recalc] = ov::Dimension::dynamic();
174-
}
175-
dims.erase(dims.begin() + idx_target);
176-
reshaped_input_pshape = ov::PartialShape(dims);
177-
} else {
178-
reshaped_input_pshape = input_pshape;
179-
}
180-
return reshaped_input_pshape;
181-
};
182-
183145
auto get_transposed_input_shape = [&](const ov::PartialShape& input_pshape, size_t input_rank, size_t output_rank, bool transpose, bool first_input) {
184146
ov::PartialShape transposed_input_pshape;
185147

@@ -214,30 +176,20 @@ std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<con
214176
return transposed_input_pshape;
215177
};
216178

217-
auto reshaped_input0_pshape = get_reshaped_input_shape(input_layouts[0].get_partial_shape(),
218-
primitive->input0_broadcast_target_shape,
219-
primitive->input0_reshape_pattern);
220-
auto reshaped_input1_pshape = get_reshaped_input_shape(input_layouts[1].get_partial_shape(),
221-
primitive->input1_broadcast_target_shape,
222-
primitive->input1_reshape_pattern);
179+
auto input0_pshape = input_layouts[0].get_partial_shape();
180+
auto input1_pshape = input_layouts[1].get_partial_shape();
223181

224182
bool reordered = primitive->input_rank > 4 || primitive->weight_rank > 4;
225183
size_t output_rank = std::max(primitive->input_rank, primitive->weight_rank);
226184
size_t input_rank = reordered ? output_rank : primitive->input_rank;
227185
size_t weight_rank = reordered ? output_rank : primitive->weight_rank;
228186

229-
auto transposed_input0_pshape = get_transposed_input_shape(reshaped_input0_pshape, input_rank, output_rank, primitive->transpose_input0, true);
230-
auto transposed_input1_pshape = get_transposed_input_shape(reshaped_input1_pshape, weight_rank, output_rank, primitive->transpose_input1, false);
187+
auto transposed_input0_pshape = get_transposed_input_shape(input0_pshape, input_rank, output_rank, primitive->transpose_input0, true);
188+
auto transposed_input1_pshape = get_transposed_input_shape(input1_pshape, weight_rank, output_rank, primitive->transpose_input1, false);
231189

232190
std::vector<layout> layouts = input_layouts;
233191
layouts[0].set_partial_shape(transposed_input0_pshape);
234-
if (primitive->input0_broadcast_target_shape.size() > input_rank) {
235-
layouts[0].format = format::adjust_to_rank(layouts[0].format, input_rank);
236-
}
237192
layouts[1].set_partial_shape(transposed_input1_pshape);
238-
if (primitive->input1_broadcast_target_shape.size() > weight_rank) {
239-
layouts[1].format = format::adjust_to_rank(layouts[1].format, weight_rank);
240-
}
241193

242194
if (primitive->input_size() == 3) {
243195
auto bias_pshape = input_layouts[2].get_partial_shape();

src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

5+
#include "intel_gpu/op/gemm.hpp"
6+
#include "intel_gpu/plugin/common_utils.hpp"
57
#include "intel_gpu/graph/kernel_impl_params.hpp"
68
#include "multi_stage_primitive.hpp"
79

@@ -173,14 +175,46 @@ struct gemm_impl : multi_stage_primitive<gemm> {
173175
params.beta = primitive->beta;
174176
params.transpose_input0 = primitive->transpose_input0;
175177
params.transpose_input1 = primitive->transpose_input1;
176-
params.input0_target_shape = primitive->input0_broadcast_target_shape;
177-
params.input1_target_shape = primitive->input1_broadcast_target_shape;
178-
params.input0_output_pattern = primitive->input0_reshape_pattern;
179-
params.input1_output_pattern = primitive->input0_reshape_pattern;
180178
params.input0_order = primitive->input0_transpose_order;
181179
params.input1_order = primitive->input1_transpose_order;
182180
params.output_order = primitive->output_transpose_order;
183181

182+
auto input0_pshape = impl_param.input_layouts[0].get_partial_shape();
183+
auto input1_pshape = impl_param.input_layouts[1].get_partial_shape();
184+
const auto is_broadcastable = input0_pshape.rank().is_static() &&
185+
input1_pshape.rank().is_static() &&
186+
input0_pshape.size() > 1 &&
187+
input1_pshape.size() > 1 &&
188+
(primitive->input_rank == primitive->weight_rank);
189+
if (is_broadcastable) {
190+
auto transpose_pshape = [](const ov::PartialShape pshape, const std::vector<int64_t>& order) {
191+
auto transposed_pshape = ov::PartialShape::dynamic(pshape.rank());
192+
for (size_t i = 0; i < order.size(); i++) {
193+
transposed_pshape[i] = pshape[order[i]];
194+
}
195+
return transposed_pshape;
196+
};
197+
size_t max_rank = input0_pshape.size();
198+
auto default_order = ov::intel_gpu::op::Gemm::default_order(max_rank);
199+
auto input0_trans_pshape = (primitive->input0_transpose_order != default_order) ?
200+
transpose_pshape(input0_pshape, primitive->input0_transpose_order) :
201+
input0_pshape;
202+
auto input1_trans_pshape = (primitive->input1_transpose_order != default_order) ?
203+
transpose_pshape(input1_pshape, primitive->input1_transpose_order) :
204+
input1_pshape;
205+
for (size_t i = 0; i < max_rank - 2; ++i) {
206+
if (input0_trans_pshape[i].is_static() && input1_trans_pshape[i].is_static()) {
207+
if (input1_trans_pshape[i].get_length() > input0_trans_pshape[i].get_length()) {
208+
params.input0_reshape_axes = primitive->input0_transpose_order[i];
209+
params.input0_broadcast_val = input1_trans_pshape[i].get_length() / input0_trans_pshape[i].get_length();
210+
} else if (input0_trans_pshape[i].get_length() > input1_trans_pshape[i].get_length()) {
211+
params.input1_reshape_axes = primitive->input1_transpose_order[i];
212+
params.input1_broadcast_val = input0_trans_pshape[i].get_length() / input1_trans_pshape[i].get_length();
213+
}
214+
}
215+
}
216+
}
217+
184218
params.indirect_input0 = primitive->indirect_a && indirect;
185219
params.indirect_input1 = primitive->indirect_b && indirect;
186220
if (indirect && (primitive->indirect_a || primitive->indirect_b)) {

src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -215,41 +215,37 @@ JitConstants GemmKernelBase::GetJitConstants(const gemm_params& params) const {
215215
jit.AddConstant(MakeJitConstant("BIAS_TERM", 1));
216216
}
217217

218-
auto get_broadcast_input_str = [](const std::vector<int32_t>& target_shape) {
219-
const size_t target_rank = target_shape.size();
218+
auto get_broadcast_input_str = [](const size_t input_rank, const int64_t axes, const int64_t val) {
220219
std::vector<std::string> dims;
221-
if (target_rank == 1) {
220+
if (input_rank == 1) {
222221
dims = {"x"};
223-
} else if (target_rank == 2) {
222+
} else if (input_rank == 2) {
224223
dims = {"y", "x"};
225-
} else if (target_rank == 3) {
224+
} else if (input_rank == 3) {
226225
dims = {"f", "y", "x"};
227-
} else if (target_rank == 4) {
226+
} else if (input_rank == 4) {
228227
dims = {"b", "f", "y", "x"};
229-
} else if (target_rank == 5) {
228+
} else if (input_rank == 5) {
230229
dims = {"b", "f", "z", "y", "x"};
231-
} else if (target_rank == 6) {
230+
} else if (input_rank == 6) {
232231
dims = {"b", "f", "w", "z", "y", "x"};
233232
}
234-
int pos = 0;
235-
for (auto ts : target_shape) {
236-
if (ts != 1)
237-
break;
238-
pos += 1;
239-
}
240-
std::string str = dims[pos] + " /= " + std::to_string(target_shape[pos]) + ";";
241-
return str;
233+
return dims[axes] + " /= " + std::to_string(val) + ";";
242234
};
243-
if (params.input0_target_shape.size() > 1) {
235+
if (params.input0_broadcast_val != 0) {
244236
jit.AddConstants({
245237
MakeJitConstant("BROADCAST_INPUT0", true),
246-
MakeJitConstant("DO_BROADCAST_INPUT0", get_broadcast_input_str(params.input0_target_shape)),
238+
MakeJitConstant("DO_BROADCAST_INPUT0", get_broadcast_input_str(params.inputs[0].GetDims().size(),
239+
params.input0_reshape_axes,
240+
params.input0_broadcast_val)),
247241
});
248242
}
249-
if (params.input1_target_shape.size() > 1) {
243+
if (params.input1_broadcast_val != 0) {
250244
jit.AddConstants({
251245
MakeJitConstant("BROADCAST_INPUT1", true),
252-
MakeJitConstant("DO_BROADCAST_INPUT1", get_broadcast_input_str(params.input1_target_shape)),
246+
MakeJitConstant("DO_BROADCAST_INPUT1", get_broadcast_input_str(params.inputs[1].GetDims().size(),
247+
params.input1_reshape_axes,
248+
params.input1_broadcast_val)),
253249
});
254250
}
255251

src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ struct gemm_params : public base_params {
1919
float beta;
2020
uint32_t transpose_input0;
2121
uint32_t transpose_input1;
22-
std::vector<int32_t> input0_target_shape;
23-
std::vector<int32_t> input1_target_shape;
24-
std::vector<int64_t> input0_output_pattern;
25-
std::vector<int64_t> input1_output_pattern;
2622
std::vector<int64_t> input0_order;
2723
std::vector<int64_t> input1_order;
2824
std::vector<int64_t> output_order;
25+
int64_t input0_reshape_axes = 0;
26+
int64_t input1_reshape_axes = 0;
27+
int64_t input0_broadcast_val = 0;
28+
int64_t input1_broadcast_val = 0;
2929
DataTensor beam_table;
3030
bool indirect_input0 = false;
3131
bool indirect_input1 = false;

0 commit comments

Comments
 (0)