41 changes: 27 additions & 14 deletions onnxruntime/core/providers/webgpu/nn/conv.cc
@@ -68,7 +68,14 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
std::transform(local_strides.begin(), local_strides.end(), std::back_inserter(strides), transform_dim);
std::transform(local_dilations.begin(), local_dilations.end(), std::back_inserter(dilations), transform_dim);
auto rank = input_shape.NumDimensions();
const InlinedVector<size_t> perm = {2, 3, 1, 0};
/* To use transpose-shared instead of transpose-naive when perm is {2, 3, 1, 0}, we transpose
the kernel in two steps: perm{2, 3, 1, 0} = perm_1{0, 2, 3, 1} followed by perm_2{1, 2, 3, 0}.
For example, if kernel_shape is [3, 4, 5, 6], transposed_kernel_shape is [5, 6, 4, 3].
step 1: kernel_shape[3, 4, 5, 6] + perm1{0, 2, 3, 1} = transposed_kernel1_shape[3, 5, 6, 4]
step 2: transposed_kernel1_shape[3, 5, 6, 4] + perm2{1, 2, 3, 0} = transposed_kernel2_shape[5, 6, 4, 3]
*/
const InlinedVector<size_t> perm_1 = {0, 2, 3, 1};
const InlinedVector<size_t> perm_2 = {1, 2, 3, 0};
if (rank > 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d and Conv2d are supported.");
} else if (rank == 4) {
@@ -104,11 +111,13 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2;
std::vector<uint32_t> updated_pads{pad0, pad1};
if (conv_attrs_.group > 1) {
Tensor transposed_kernel;
Tensor transposed_kernel_1;
Tensor transposed_kernel_2;
if (is_channels_last) {
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
inputs[1] = &transposed_kernel;
modified_input_output_shapes[1] = transposed_kernel.Shape();
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel_1, perm_1));
ORT_RETURN_IF_ERROR(TransposeKernel(context, &transposed_kernel_1, transposed_kernel_1.Shape(), &transposed_kernel_2, perm_2));
inputs[1] = &transposed_kernel_2;
modified_input_output_shapes[1] = transposed_kernel_2.Shape();
}
auto output_channels_per_group = output_channels / conv_attrs_.group;
auto components = static_cast<int>(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1);
@@ -138,7 +147,8 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context

const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0;
if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) {
Tensor transposed_kernel;
Tensor transposed_kernel_1;
Tensor transposed_kernel_2;
TensorShape input_reshape;
TensorShape kernel_reshape;
TensorShape matmul_output_shape;
@@ -147,8 +157,9 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
if (is_channels_last) {
// Transpose weights

ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
inputs[1] = &transposed_kernel;
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel_1, perm_1));
ORT_RETURN_IF_ERROR(TransposeKernel(context, &transposed_kernel_1, transposed_kernel_1.Shape(), &transposed_kernel_2, perm_2));
inputs[1] = &transposed_kernel_2;
if (same_size) {
const auto shared_dim = input_height * input_width * input_channels;
input_reshape = TensorShape({1, batch, shared_dim});
@@ -160,7 +171,7 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels});
}
matmul_inputs.push_back(input);
matmul_inputs.push_back(&transposed_kernel);
matmul_inputs.push_back(&transposed_kernel_2);
matmul_input_reshapes.push_back(input_reshape);
matmul_input_reshapes.push_back(kernel_reshape);
} else {
@@ -205,14 +216,16 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
}
}
// Transpose weights
Tensor transposed_kernel;
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
Tensor transposed_kernel_1;
Tensor transposed_kernel_2;
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel_1, perm_1));
ORT_RETURN_IF_ERROR(TransposeKernel(context, &transposed_kernel_1, transposed_kernel_1.Shape(), &transposed_kernel_2, perm_2));
auto dim_a_outer = static_cast<uint32_t>(is_channels_last ? output_height * output_width : output_channels);
auto dim_b_outer = static_cast<uint32_t>(is_channels_last ? output_channels : output_height * output_width);
auto dim_inner = static_cast<uint32_t>(kernel_height * kernel_width * input_channels);
inputs[1] = &transposed_kernel;
TensorShape transposed_kernel_shape = transposed_kernel.Shape();
modified_input_output_shapes[1] = transposed_kernel.Shape();
inputs[1] = &transposed_kernel_2;
TensorShape transposed_kernel_shape = transposed_kernel_2.Shape();
modified_input_output_shapes[1] = transposed_kernel_2.Shape();
Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, modified_input_output_shapes);
return context.RunProgram(conv2d_mm_program);
}
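Editor's note on the decomposition: treating a permutation as out[i] = in[perm[i]], applying perm_1 and then perm_2 composes to composed[i] = perm_1[perm_2[i]], which is exactly {2, 3, 1, 0}. A minimal standalone C++ sanity check of this (illustrative only, not part of the change):

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

int main() {
  // Shapes transform as out[i] = in[perm[i]]; chaining perm_1 then perm_2
  // is therefore the single permutation composed[i] = perm_1[perm_2[i]].
  constexpr std::array<size_t, 4> perm_1{0, 2, 3, 1};
  constexpr std::array<size_t, 4> perm_2{1, 2, 3, 0};
  std::array<size_t, 4> composed{};
  for (size_t i = 0; i < 4; ++i) composed[i] = perm_1[perm_2[i]];
  assert((composed == std::array<size_t, 4>{2, 3, 1, 0}));

  // Worked example from the comment: [3, 4, 5, 6] -> [3, 5, 6, 4] -> [5, 6, 4, 3].
  std::array<int64_t, 4> shape{3, 4, 5, 6}, step1{}, step2{};
  for (size_t i = 0; i < 4; ++i) step1[i] = shape[perm_1[i]];
  for (size_t i = 0; i < 4; ++i) step2[i] = step1[perm_2[i]];
  assert((step2 == std::array<int64_t, 4>{5, 6, 4, 3}));
  return 0;
}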
53 changes: 35 additions & 18 deletions onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -74,20 +74,35 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);

if (use_shared_) {
std::string input_str = std::string("a_indices_t(") + (map_first_channels_first_ ? "batch, " : "") + "input_row, input_col)";
std::string output_str = std::string("output_indices_t(") + (map_first_channels_first_ ? "batch, " : "") + "output_row, output_col)";
std::string input_cond_str = map_first_channels_first_ ? "input_row < uniforms.a_shape[1] && input_col < uniforms.a_shape[2]"
: "input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]";
std::string output_cond_str = map_first_channels_first_ ? "output_row < uniforms.output_shape[1] && output_col < uniforms.output_shape[2]"
: "output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]";
shader.AdditionalImplementation() << "var<workgroup> tile : array<array<output_value_t, tile_size + 1>, tile_size>;\n";
shader.MainFunctionBody() << " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n"
if (map_first_channels_first_) {
shader.MainFunctionBody() << " let batch = workgroup_id.z;\n"
" let stride_x = (uniforms.output_shape[2] - 1) / tile_size + 1;\n"
" let stride_y = (uniforms.output_shape[1] - 1) / tile_size + 1;\n"
" let workgroup_id_xy = workgroup_idx % (stride_x * stride_y);\n"
" let workgroup_id_x = workgroup_id_xy % stride_x;\n"
" let workgroup_id_y = workgroup_id_xy / stride_x;\n";
} else {
shader.MainFunctionBody() << " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n"
" let workgroup_id_x = workgroup_idx % stride;\n"
" let workgroup_id_y = workgroup_idx / stride;\n"
" let input_col = workgroup_id_y * tile_size + local_id.x;\n"
" let workgroup_id_y = workgroup_idx / stride;\n";
}
shader.MainFunctionBody() << " let input_col = workgroup_id_y * tile_size + local_id.x;\n"
" let input_row = workgroup_id_x * tile_size + local_id.y;\n"
" if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n"
<< " tile[local_id.y][local_id.x] = " << input.GetByIndices("a_indices_t(input_row, input_col)") << ";\n"
<< " if (" << input_cond_str << ") {\n"
<< " tile[local_id.y][local_id.x] = " << input.GetByIndices(input_str) << ";\n"
<< " }\n"
" workgroupBarrier();\n"
" let output_col = workgroup_id_x * tile_size + local_id.x;\n"
" let output_row = workgroup_id_y * tile_size + local_id.y;\n"
" if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n"
<< " " << output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") << "\n"
<< " if (" << output_cond_str << ") {\n"
<< " " << output.SetByIndices(output_str, "tile[local_id.x][local_id.y]") << "\n"
<< " }";
} else {
shader.AdditionalImplementation() << "fn perm(i: output_indices_t)->a_indices_t {\n"
@@ -126,24 +141,25 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context,
SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm);
const bool channels_last = new_perm == TensorShapeVector({2, 3, 1});
const bool channels_first = new_perm == TensorShapeVector({3, 1, 2});
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
const bool map_first_channels_first = new_perm == TensorShapeVector({0, 2, 3, 1});
const bool map_last = new_perm == TensorShapeVector({1, 2, 3, 0});
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first || map_first_channels_first || map_last;
auto new_input_shape = input_shape;
TensorShape new_output_shape(output_dims);

if (use_shared) {
new_input_shape = channels_last
? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
: channels_first
? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
: new_shape;
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
new_input_shape = channels_last ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
: channels_first ? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
: (map_first_channels_first ? TensorShape({new_shape[0], new_shape[1], new_shape[2] * new_shape[3]})
: (map_last ? TensorShape({new_shape[0], new_shape[1] * new_shape[2] * new_shape[3]}) : new_shape));
new_output_shape = map_first_channels_first ? TensorShape({new_input_shape[0], new_input_shape[2], new_input_shape[1]}) : TensorShape({new_input_shape[1], new_input_shape[0]});
}

uint32_t output_size = onnxruntime::narrow<uint32_t>(input_shape.Size());
TransposeProgram program{permutations, use_shared};
TransposeProgram program{permutations, use_shared, map_first_channels_first};

program
.CacheHint(absl::StrJoin(permutations, "-"))
.CacheHint(use_shared, map_first_channels_first, absl::StrJoin(permutations, "-"))
.AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
.AddUniformVariables({
@@ -152,8 +168,9 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context,

if (use_shared) {
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)));
program.SetDispatchGroupSize(static_cast<uint32_t>(((map_first_channels_first ? new_output_shape[2] : new_output_shape[1]) + TILE_SIZE - 1) / TILE_SIZE),
static_cast<uint32_t>(((map_first_channels_first ? new_output_shape[1] : new_output_shape[0]) + TILE_SIZE - 1) / TILE_SIZE),
map_first_channels_first ? static_cast<uint32_t>(new_output_shape[0]) : 1);
} else {
program.SetWorkgroupSize(WORKGROUP_SIZE);

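For orientation: every permutation on the use_shared path is reduced to a (possibly batched) 2-D tile transpose before dispatch; the map_first_channels_first case keeps dim 0 as a batch mapped to workgroup z. The sketch below mirrors that collapse and dispatch-size logic in standalone form; the helper names and the tile-size constant are illustrative assumptions, not the file's actual API:

#include <cstdint>
#include <vector>

constexpr int64_t kTileSize = 16;  // assumption: stands in for TILE_SIZE in transpose.cc

// Collapse the squeezed shape so the shader only ever sees [rows, cols] or
// [batch, rows, cols].
std::vector<int64_t> CollapseForSharedTranspose(const std::vector<int64_t>& s,
                                                bool channels_last, bool channels_first,
                                                bool map_first_channels_first, bool map_last) {
  if (channels_last) return {s[0], s[1] * s[2]};                   // perm {2,3,1}
  if (channels_first) return {s[0] * s[1], s[2]};                  // perm {3,1,2}
  if (map_first_channels_first) return {s[0], s[1], s[2] * s[3]};  // perm {0,2,3,1}
  if (map_last) return {s[0], s[1] * s[2] * s[3]};                 // perm {1,2,3,0}
  return s;
}

// Workgroups tile the output: x over columns, y over rows, z over the batch.
struct Dispatch { uint32_t x, y, z; };
Dispatch TileDispatch(const std::vector<int64_t>& out_shape, bool batched) {
  const int64_t cols = batched ? out_shape[2] : out_shape[1];
  const int64_t rows = batched ? out_shape[1] : out_shape[0];
  return {static_cast<uint32_t>((cols + kTileSize - 1) / kTileSize),
          static_cast<uint32_t>((rows + kTileSize - 1) / kTileSize),
          batched ? static_cast<uint32_t>(out_shape[0]) : 1u};
}

int main() {
  // Example: perm {0,2,3,1} on [3, 4, 5, 6] -> batched transpose of [3, 4, 30].
  auto in = CollapseForSharedTranspose({3, 4, 5, 6}, false, false, true, false);
  std::vector<int64_t> out{in[0], in[2], in[1]};  // per-batch 4x30 -> 30x4, so [3, 30, 4]
  Dispatch d = TileDispatch(out, /*batched=*/true);
  return (d.x == 1 && d.y == 2 && d.z == 3) ? 0 : 1;
}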
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webgpu/tensor/transpose.h
@@ -23,8 +23,8 @@ class Transpose final : public WebGpuKernel, public TransposeBase {

class TransposeProgram final : public Program<TransposeProgram> {
public:
TransposeProgram(const gsl::span<const size_t>& permutations, bool use_shared)
: Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared) {
TransposeProgram(const gsl::span<const size_t>& permutations, bool use_shared, bool map_first_channels_first = false)
: Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared), map_first_channels_first_(map_first_channels_first) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -35,6 +35,7 @@ class TransposeProgram final : public Program<TransposeProgram> {
private:
InlinedVector<int64_t> perm_;
const bool use_shared_;
const bool map_first_channels_first_;
};

} // namespace webgpu
22 changes: 22 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/transpose_test.cc
@@ -689,6 +689,28 @@ TEST(TransposeOpTest, NDim) {
TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals2);
}

TEST(TransposeOpTest, 4Dim_perm2310) {
std::vector<int64_t> input_shape({2, 2, 2, 2});
std::vector<float> input_vals = {1.0f, 2.0f, 3.0f, 4.0f,
5.0f, 6.0f, 7.0f, 8.0f,
9.0f, 10.0f, 11.0f, 12.0f,
13.0f, 14.0f, 15.0f, 16.0f};

std::vector<int64_t> perm = {0, 2, 3, 1};
std::vector<float> expected_vals = {1.0f, 5.0f, 2.0f, 6.0f,
3.0f, 7.0f, 4.0f, 8.0f,
9.0f, 13.0f, 10.0f, 14.0f,
11.0f, 15.0f, 12.0f, 16.0f};
TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals);

perm = {1, 2, 3, 0};
std::vector<float> expected_vals2 = {1.0f, 9.0f, 5.0f, 13.0f,
2.0f, 10.0f, 6.0f, 14.0f,
3.0f, 11.0f, 7.0f, 15.0f,
4.0f, 12.0f, 8.0f, 16.0f};
TransposeTest(input_shape, expected_vals, &perm, input_shape, expected_vals2);
}

TEST(TransposeOpTest, DoTransposeImpl) {
std::vector<int64_t> input_shape({5, 2, 1, 3});
std::vector<float> input_vals(30);
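The expected values in 4Dim_perm2310 can be reproduced with a naive reference transpose; a minimal standalone sketch (independent of the test harness) checking the {0, 2, 3, 1} case:

#include <array>
#include <cassert>
#include <cstddef>

int main() {
  // Reference transpose of a [2, 2, 2, 2] tensor holding 1..16 in row-major order.
  constexpr std::array<size_t, 4> perm{0, 2, 3, 1};
  std::array<float, 16> in{};
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i + 1);

  std::array<float, 16> out{};
  for (size_t i0 = 0; i0 < 2; ++i0)
    for (size_t i1 = 0; i1 < 2; ++i1)
      for (size_t i2 = 0; i2 < 2; ++i2)
        for (size_t i3 = 0; i3 < 2; ++i3) {
          const std::array<size_t, 4> o{i0, i1, i2, i3};
          std::array<size_t, 4> s{};
          for (size_t k = 0; k < 4; ++k) s[perm[k]] = o[k];  // source index
          out[((o[0] * 2 + o[1]) * 2 + o[2]) * 2 + o[3]] =
              in[((s[0] * 2 + s[1]) * 2 + s[2]) * 2 + s[3]];
        }

  // Matches expected_vals in the test above.
  const std::array<float, 16> expected{1, 5, 2, 6, 3, 7, 4, 8, 9, 13, 10, 14, 11, 15, 12, 16};
  assert(out == expected);
  return 0;
}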