
Commit fdebf29

[TOSA] Handle valid padding in aten.convolution
1 parent bc657db

2 files changed, +52 -2 lines

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 6 additions & 2 deletions
@@ -2380,8 +2380,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
   // spatial direction, implying the same top=bottom=height and left=right=width
   // values for TOSA.
-  SmallVector<int64_t> padding(
-      {padding_2d[0], padding_2d[0], padding_2d[1], padding_2d[1]});
+
+  int64_t padH = padding_2d[0];
+  // When padding is 'Valid', Torch produces 1D padding with only one value.
+  int64_t padW = (padding_2d.size() > 1) ? padding_2d[1] : padding_2d[0];
+
+  SmallVector<int64_t> padding({padH, padH, padW, padW});
 
   SmallVector<int64_t, 2> dilation;
   if (!matchPattern(adaptor.getDilation(), m_TorchListOfConstantInts(dilation)))
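For reference, the normalization this hunk performs can be sketched standalone. This is a minimal illustration, assuming plain std::vector in place of llvm::SmallVector; toTosaPadding is a hypothetical name for the sketch, not a function from the patch:

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Expand a Torch padding list (one or two values) into the four-value
// {top, bottom, left, right} form that TOSA expects. With padding='valid',
// Torch emits a single-element list, so that one value serves both the
// height and width dimensions.
std::vector<int64_t> toTosaPadding(const std::vector<int64_t> &padding2d) {
  assert(!padding2d.empty() && "expected at least one padding value");
  int64_t padH = padding2d[0];
  int64_t padW = padding2d.size() > 1 ? padding2d[1] : padding2d[0];
  return {padH, padH, padW, padW};
}

int main() {
  for (int64_t p : toTosaPadding({0}))    // 'valid': single value -> 0 0 0 0
    std::cout << p << ' ';
  std::cout << '\n';
  for (int64_t p : toTosaPadding({1, 2})) // explicit {height, width} -> 1 1 2 2
    std::cout << p << ' ';
  std::cout << '\n';
}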

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 46 additions & 0 deletions
@@ -3891,6 +3891,52 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.convolution$valid_padding(
+// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> {
+// CHECK: %[[INPUT_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[INPUT_TENSOR]] : !torch.vtensor<[1,1,5,5],f32> -> tensor<1x1x5x5xf32>
+// CHECK: %[[WEIGHT_CONST:.*]] = "tosa.const"() <{values = dense<-7.486820e-03> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
+// CHECK: %[[BIAS_CONST:.*]] = "tosa.const"() <{values = dense<0.536443591> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[STRIDE_H:.*]] = torch.constant.int 1
+// CHECK: %[[STRIDE_W:.*]] = torch.constant.int 1
+// CHECK: %[[STRIDES_LIST:.*]] = torch.prim.ListConstruct %[[STRIDE_H]], %[[STRIDE_W]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[PADDING_VAL:.*]] = torch.constant.int 0
+// CHECK: %[[PADDING_LIST:.*]] = torch.prim.ListConstruct %[[PADDING_VAL]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[DILATION_H:.*]] = torch.constant.int 1
+// CHECK: %[[DILATION_W:.*]] = torch.constant.int 1
+// CHECK: %[[DILATIONS_LIST:.*]] = torch.prim.ListConstruct %[[DILATION_H]], %[[DILATION_W]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[GROUPS_VAL:.*]] = torch.constant.bool false
+// CHECK: %[[OUTPUT_PADDING_VAL:.*]] = torch.constant.int 0
+// CHECK: %[[OUTPUT_PADDING_LIST:.*]] = torch.prim.ListConstruct %[[OUTPUT_PADDING_VAL]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[CONV_DIMENSIONS:.*]] = torch.constant.int 1
+// CHECK: %[[WEIGHT_TRANSPOSED:.*]] = tosa.transpose %[[WEIGHT_CONST]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+// CHECK: %[[INPUT_TRANSPOSED:.*]] = tosa.transpose %[[INPUT_BUILTIN]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x1x5x5xf32>) -> tensor<1x5x5x1xf32>
+// CHECK: %[[ZERO_BIAS_OP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[ZERO_BIAS_OP_2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[CONV_RESULT_TOSA:.*]] = tosa.conv2d %[[INPUT_TRANSPOSED]], %[[WEIGHT_TRANSPOSED]], %[[BIAS_CONST]], %[[ZERO_BIAS_OP]], %[[ZERO_BIAS_OP_2]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x5x5x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x1xf32>
+// CHECK: %[[OUTPUT_TRANSPOSED:.*]] = tosa.transpose %[[CONV_RESULT_TOSA]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x5x5x1xf32>) -> tensor<1x1x5x5xf32>
+// CHECK: %[[OUTPUT_TENSOR:.*]] = torch_c.from_builtin_tensor %[[OUTPUT_TRANSPOSED]] : tensor<1x1x5x5xf32> -> !torch.vtensor<[1,1,5,5],f32>
+// CHECK: return %[[OUTPUT_TENSOR]] : !torch.vtensor<[1,1,5,5],f32>
+func.func @torch.aten.convolution$valid_padding(%arg0: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> {
+  %0 = torch.vtensor.literal(dense<-7.486820e-03> : tensor<1x1x1x1xf32>) : !torch.vtensor<[1,1,1,1],f32>
+  %1 = torch.vtensor.literal(dense<0.536443591> : tensor<1xf32>) : !torch.vtensor<[1],f32>
+  %int1 = torch.constant.int 1
+  %int1_0 = torch.constant.int 1
+  %2 = torch.prim.ListConstruct %int1, %int1_0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int0 = torch.constant.int 0
+  %3 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
+  %int1_1 = torch.constant.int 1
+  %int1_2 = torch.constant.int 1
+  %4 = torch.prim.ListConstruct %int1_1, %int1_2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %false = torch.constant.bool false
+  %int0_3 = torch.constant.int 0
+  %5 = torch.prim.ListConstruct %int0_3 : (!torch.int) -> !torch.list<int>
+  %int1_4 = torch.constant.int 1
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %false, %5, %int1_4 : !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,1,1],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,5,5],f32>
+  return %6 : !torch.vtensor<[1,1,5,5],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.max_pool2d$zero_pad_with_sliced_input(
 // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,1,56,56],f32>) -> !torch.vtensor<[1,1,27,27],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,56,56],f32> -> tensor<1x1x56x56xf32>
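As a shape sanity check on the new test: with pad = 0 ('valid'), a 1x1 kernel, stride 1, and dilation 1, the output stays 5x5, matching the CHECK lines above. Below is a minimal sketch of the standard output-size formula the source comment refers to ("2*pad in each spatial direction"); ofmDim is an illustrative name, not from the patch:

#include <cstdint>
#include <iostream>

// Output feature-map size along one spatial dimension:
//   out = (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1
int64_t ofmDim(int64_t in, int64_t pad, int64_t dilation, int64_t kernel,
               int64_t stride) {
  return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
  // The valid-padding test: 5x5 input, 1x1 kernel, stride 1, dilation 1,
  // pad 0, so the output spatial dims remain 5.
  std::cout << ofmDim(5, 0, 1, 1, 1) << '\n'; // prints 5
}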
