
Commit 7761e22

fix review comments
Change-Id: I8a211fa3f0468db1765ce57447a8dd422431067f
1 parent fe0e18c commit 7761e22

3 files changed: +51 -42 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 18 additions & 11 deletions
@@ -6123,8 +6123,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
     AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
-    DenseI64ArrayAttr &pad,
-    SmallVectorImpl<int64_t> *explicitNHWCPad = nullptr) {
+    DenseI64ArrayAttr &pad, SmallVectorImpl<int64_t> &explicitNHWCPad) {

   RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
   if (!inputTy)
@@ -6173,14 +6172,9 @@ static LogicalResult getOutputTypeAndPoolingParameters(
                      m_TorchConstantBool(&countIncludePad)) ||

         countIncludePad)) {
-    if (!explicitNHWCPad)
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `count_include_pad` value, for tosa AvgPool "
-              "`count_include_pad` value should be `False`.");
-
     // Remember the spatial padding so we can emit an NHWC tosa.pad right
     // after the transpose.
-    explicitNHWCPad->assign(
+    explicitNHWCPad.assign(
         {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});

     auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
@@ -6193,7 +6187,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
     // the extra zeros supplied by the explicit pad.
     SmallVector<int64_t> paddedShape(inputTy.getShape().begin(),
                                      inputTy.getShape().end());
-    // Height stored at rank-2, width at rank-1 for NCHW shapes.
+    // Height stored at rank-2 and width at rank-1 while the tensor is still
+    // in NCHW order; the NHWC transpose happens later.
     paddedShape[inputRank - 2] =
         addPad(paddedShape[inputRank - 2], paddingInts[0], paddingInts[0]);
     paddedShape[inputRank - 1] =
@@ -6223,6 +6218,18 @@ static LogicalResult getOutputTypeAndPoolingParameters(
   return success();
 }

+template <typename AtenOpT, typename tosaOp>
+static LogicalResult getOutputTypeAndPoolingParameters(
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
+    SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
+    DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
+    DenseI64ArrayAttr &pad) {
+  SmallVector<int64_t, 4> ignoredExplicitPad;
+  return getOutputTypeAndPoolingParameters<AtenOpT, tosaOp>(
+      op, rewriter, inputXchw, dilationArray, outputTy, kernel, stride, pad,
+      ignoredExplicitPad);
+}
+
 class ConvertAtenMaxPool2dOp
     : public ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp> {
 public:
@@ -6348,7 +6355,7 @@ class ConvertAtenAvgPool2dOp
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
             op, rewriter, self, dilationArray, outputTy, kernel, stride, pad,
-            &explicitNHWCPad)))
+            explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");

@@ -6407,7 +6414,7 @@ class ConvertAtenAvgPool1dOp
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
             op, rewriter, reshapedSelf, dilationArray, outputTy, kernel,
-            stride, pad, &explicitNHWCPad)))
+            stride, pad, explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");

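Note on the signature change above: the optional out-parameter (a SmallVectorImpl<int64_t>* defaulting to nullptr) becomes a mandatory reference, and a new convenience overload feeds a scratch ignoredExplicitPad buffer to callers that do not need the explicit padding, so the shared implementation no longer carries a null check or the associated match-failure path. Below is a minimal standalone sketch of that overload pattern, with std::vector standing in for llvm::SmallVectorImpl and every name illustrative rather than taken from the patch:

#include <cstdint>
#include <iostream>
#include <vector>

// Full version: always records the spatial padding it derives, in the
// {top, bottom, left, right} order an NHWC pad expects.
static bool computePoolingParams(const std::vector<int64_t> &paddingInts,
                                 std::vector<int64_t> &explicitNHWCPad) {
  explicitNHWCPad.assign(
      {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});
  return true;
}

// Convenience overload for callers that do not consume the explicit pad:
// it forwards a throwaway buffer, mirroring `ignoredExplicitPad` above.
static bool computePoolingParams(const std::vector<int64_t> &paddingInts) {
  std::vector<int64_t> ignoredExplicitPad;
  return computePoolingParams(paddingInts, ignoredExplicitPad);
}

int main() {
  std::vector<int64_t> pad;
  computePoolingParams({1, 2}, pad); // caller that needs the padding
  computePoolingParams({1, 2});      // caller that does not
  for (int64_t v : pad)
    std::cout << v << ' '; // prints: 1 1 2 2
  std::cout << '\n';
}

Passing by reference rather than pointer is also why the two pooling call sites drop the address-of: `&explicitNHWCPad` becomes `explicitNHWCPad`.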
lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 3 additions & 1 deletion
@@ -636,7 +636,9 @@ Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
       0, 0, padExtents[0], padExtents[1], padExtents[2], padExtents[3], 0, 0};
   Value nhwcPadShape = tosa::getTosaConstShape(rewriter, loc, nhwcPadding);

-  auto inputTy = cast<RankedTensorType>(inputNHWC.getType());
+  auto inputTy = dyn_cast<RankedTensorType>(inputNHWC.getType());
+  if (!inputTy)
+    return inputNHWC;
   SmallVector<int64_t, 4> resultShape(inputTy.getShape().begin(),
                                       inputTy.getShape().end());
   auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {

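Context for the cast-to-dyn_cast change: in LLVM-style casting, cast<T> asserts when the dynamic type does not match, while dyn_cast<T> returns a null value so the caller can bail out gracefully, here by handing back the un-padded inputNHWC. A rough standalone sketch of that guard pattern in plain C++, with dynamic_cast playing the role of dyn_cast and all types illustrative:

#include <iostream>

struct Type { virtual ~Type() = default; };
struct RankedTensorType : Type { int rank = 4; };
struct UnrankedTensorType : Type {};

// Mirrors the patched helper: if the value is not ranked, hand the input
// back unchanged instead of asserting (the `return inputNHWC;` early exit).
const Type &padIfRanked(const Type &inputNHWC) {
  const auto *ranked = dynamic_cast<const RankedTensorType *>(&inputNHWC);
  if (!ranked)
    return inputNHWC; // graceful fallback, no crash
  std::cout << "padding a rank-" << ranked->rank << " tensor\n";
  return inputNHWC;
}

int main() {
  RankedTensorType ranked;
  UnrankedTensorType unranked;
  padIfRanked(ranked);   // takes the padding path
  padIfRanked(unranked); // silently returns its input
}

Returning the input unchanged keeps the helper total: in this sketch, callers get a valid value either way instead of tripping an assertion on an unexpected type.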
test/Conversion/TorchToTosa/basic.mlir

Lines changed: 30 additions & 30 deletions
@@ -4354,8 +4354,8 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{

 // -----
 // CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
+// CHECK-SAME: %[[ARG_INPUT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG_INPUT]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 0
 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 3
@@ -4365,17 +4365,17 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
 // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
-// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
-// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
-// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
-// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
-// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
-// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
-// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
+// CHECK: %[[NHWC_TRANSPOSE:.*]] = tosa.transpose %[[INPUT_TENSOR]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[PADDING_SHAPE:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[PAD_FILL:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[PADDED_NHWC:.*]] = tosa.pad %[[NHWC_TRANSPOSE]], %[[PADDING_SHAPE]], %[[PAD_FILL]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
+// CHECK: %[[AVG_POOL_LHS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[AVG_POOL_RHS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[AVG_POOL_RESULT:.*]] = tosa.avg_pool2d %[[PADDED_NHWC]], %[[AVG_POOL_LHS_ZP]], %[[AVG_POOL_RHS_ZP]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[AVG_POOL_RESULT]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
+// CHECK: %[[RESULT_CAST:.*]] = tensor.cast %[[RESULT_NCHW]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
+// CHECK: %[[TORCH_RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_CAST]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
+// CHECK: return %[[TORCH_RESULT]] : !torch.vtensor<[1,192,35,35],f32>
 // CHECK: }
 func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
   %int0 = torch.constant.int 0
@@ -4394,30 +4394,30 @@ func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,

 // -----
 // CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
+// CHECK-SAME: %[[ARG_INPUT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG_INPUT]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_3:.*]] = torch.constant.int 3
 // CHECK: %[[VAL_4:.*]] = torch.constant.bool false
 // CHECK: %[[VAL_5:.*]] = torch.constant.bool true
 // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
-// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_10]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
-// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
-// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
-// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
-// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
-// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
-// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
-// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
-// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
-// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
+// CHECK: %[[RESHAPE_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[RESHAPED_INPUT:.*]] = tosa.reshape %[[INPUT_TENSOR]], %[[RESHAPE_SHAPE]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[TRANSPOSED_NHWC:.*]] = tosa.transpose %[[RESHAPED_INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[PADDING_SHAPE:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[PAD_FILL:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[PADDED_NHWC:.*]] = tosa.pad %[[TRANSPOSED_NHWC]], %[[PADDING_SHAPE]], %[[PAD_FILL]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
+// CHECK: %[[AVG_POOL_LHS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[AVG_POOL_RHS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[AVG_POOL_RESULT:.*]] = tosa.avg_pool2d %[[PADDED_NHWC]], %[[AVG_POOL_LHS_ZP]], %[[AVG_POOL_RHS_ZP]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[AVG_POOL_RESULT]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[RESHAPE_BACK_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[RESHAPED_BACK:.*]] = tosa.reshape %[[RESULT_NCHW]], %[[RESHAPE_BACK_SHAPE]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
+// CHECK: %[[RESULT_CAST:.*]] = tensor.cast %[[RESHAPED_BACK]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
+// CHECK: %[[TORCH_RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_CAST]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
+// CHECK: return %[[TORCH_RESULT]] : !torch.vtensor<[1,512,10],f32>
 // CHECK: }
 func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
   %int1 = torch.constant.int 1

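The basic.mlir changes are purely cosmetic: FileCheck's %[[NAME:pattern]] syntax binds a matched value to NAME for later reuse, so renaming the generic %[[VAL_N]] captures to semantic names such as %[[PADDED_NHWC]] and %[[AVG_POOL_RESULT]] documents the lowering pipeline without altering what the tests match. The untouched constant-setup captures keep their generic names.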