Skip to content

Commit b8f742b

Browse files
authored
[TorchToLinalg] Add lowering of torch.aten.pixel_unshuffle op (#4278)
This PR will fix the following issue: [Add lowering of torch.aten.pixel_unshuffle op to linalg dialect](#4260) This code snippet can reproduce the issue: ``` func.func @pixel_unshuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} { %int2 = torch.constant.int 2 %0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,32,2,2],f32> return %0 : !torch.vtensor<[1,32,2,2],f32> } ``` The decomposition is based on this specification: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html and PyTorch implementation could be found in main/aten/src/ATen/native/PixelShuffle.cpp: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp With code changes, torch.aten.pixel_unshuffle will be lowered to the following: ``` module { func.func @main(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} { %int2 = torch.constant.int 2 %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 %int3 = torch.constant.int 3 %int4 = torch.constant.int 4 %int5 = torch.constant.int 5 %0 = torch.prim.ListConstruct %int0, %int1, %int3, %int5, %int2, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> %1 = torch.prims.split_dim %arg0, %int2, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,4],f32> %2 = torch.prims.split_dim %1, %int4, %int2 : !torch.vtensor<[1,8,2,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,2,2],f32> %3 = torch.aten.permute %2, %0 : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.list<int> -> !torch.vtensor<[1,8,2,2,2,2],f32> %4 = torch.prims.collapse %3, %int2, %int3 : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,4,2,2],f32> %5 = torch.prims.collapse %4, %int1, %int2 : !torch.vtensor<[1,8,4,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,2,2],f32> return %5 : !torch.vtensor<[1,32,2,2],f32> } } ```
1 parent 4f572c5 commit b8f742b

File tree

10 files changed

+428
-57
lines changed

10 files changed

+428
-57
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8668,6 +8668,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
86688668
}];
86698669
}
86708670

8671+
def Torch_AtenPixelUnshuffleOp : Torch_Op<"aten.pixel_unshuffle", [
8672+
AllowsTypeRefinement,
8673+
HasValueSemantics,
8674+
ReadOnly
8675+
]> {
8676+
let summary = "Generated op for `aten::pixel_unshuffle : (Tensor, int) -> (Tensor)`";
8677+
let arguments = (ins
8678+
AnyTorchTensorType:$self,
8679+
Torch_IntType:$downscale_factor
8680+
);
8681+
let results = (outs
8682+
AnyTorchOptionalTensorType:$result
8683+
);
8684+
let hasCustomAssemblyFormat = 1;
8685+
let extraClassDefinition = [{
8686+
ParseResult AtenPixelUnshuffleOp::parse(OpAsmParser &parser, OperationState &result) {
8687+
return parseDefaultTorchOp(parser, result, 2, 1);
8688+
}
8689+
void AtenPixelUnshuffleOp::print(OpAsmPrinter &printer) {
8690+
printDefaultTorchOp(printer, *this, 2, 1);
8691+
}
8692+
}];
8693+
}
8694+
86718695
def Torch_AtenChannelShuffleOp : Torch_Op<"aten.channel_shuffle", [
86728696
AllowsTypeRefinement,
86738697
HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7613,6 +7613,56 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
76137613
" %15 = torch.aten.append.t %6, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
76147614
" return %6 : !torch.list<int>\n"
76157615
" }\n"
7616+
" func.func @\"__torch_mlir_shape_fn.aten.pixel_unshuffle\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
7617+
" %int1 = torch.constant.int 1\n"
7618+
" %int-3 = torch.constant.int -3\n"
7619+
" %str = torch.constant.str \"AssertionError: width must be divisible by downscale_factor in pixel_unshuffle\"\n"
7620+
" %int-1 = torch.constant.int -1\n"
7621+
" %str_0 = torch.constant.str \"AssertionError: height must be divisible by downscale_factor in pixel_unshuffle\"\n"
7622+
" %int-2 = torch.constant.int -2\n"
7623+
" %none = torch.constant.none\n"
7624+
" %str_1 = torch.constant.str \"AssertionError: input must be at least rank-3 in pixel_unshuffle\"\n"
7625+
" %int3 = torch.constant.int 3\n"
7626+
" %int0 = torch.constant.int 0\n"
7627+
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
7628+
" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
7629+
" torch.prim.If %1 -> () {\n"
7630+
" torch.prim.If.yield\n"
7631+
" } else {\n"
7632+
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
7633+
" torch.prim.If.yield\n"
7634+
" }\n"
7635+
" %2 = torch.aten.mul.int %arg1, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7636+
" %3 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
7637+
" %4 = torch.aten.remainder.int %3, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7638+
" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
7639+
" torch.prim.If %5 -> () {\n"
7640+
" torch.prim.If.yield\n"
7641+
" } else {\n"
7642+
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
7643+
" torch.prim.If.yield\n"
7644+
" }\n"
7645+
" %6 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
7646+
" %7 = torch.aten.remainder.int %6, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7647+
" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
7648+
" torch.prim.If %8 -> () {\n"
7649+
" torch.prim.If.yield\n"
7650+
" } else {\n"
7651+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
7652+
" torch.prim.If.yield\n"
7653+
" }\n"
7654+
" %9 = torch.aten.slice.t %arg0, %int0, %int-3, %int1 : !torch.list<int>, !torch.int, !torch.int, !torch.int -> !torch.list<int>\n"
7655+
" %10 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int\n"
7656+
" %11 = torch.aten.mul.int %10, %2 : !torch.int, !torch.int -> !torch.int\n"
7657+
" %12 = torch.aten.append.t %9, %11 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
7658+
" %13 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
7659+
" %14 = torch.aten.floordiv.int %13, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7660+
" %15 = torch.aten.append.t %9, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
7661+
" %16 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
7662+
" %17 = torch.aten.floordiv.int %16, %arg1 : !torch.int, !torch.int -> !torch.int\n"
7663+
" %18 = torch.aten.append.t %9, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
7664+
" return %9 : !torch.list<int>\n"
7665+
" }\n"
76167666
" func.func @\"__torch_mlir_shape_fn.aten.channel_shuffle\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
76177667
" %none = torch.constant.none\n"
76187668
" %str = torch.constant.str \"AssertionError: input must be at least rank-3 in channel_shuffle\"\n"
@@ -12411,6 +12461,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1241112461
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1241212462
" return %0#1 : !torch.int\n"
1241312463
" }\n"
12464+
" func.func @\"__torch_mlir_dtype_fn.aten.pixel_unshuffle\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
12465+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12466+
" return %0#1 : !torch.int\n"
12467+
" }\n"
1241412468
" func.func @\"__torch_mlir_dtype_fn.aten.channel_shuffle\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
1241512469
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1241612470
" return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 154 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3536,30 +3536,6 @@ class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
35363536
};
35373537
} // namespace
35383538

3539-
namespace { // Start of rearrangement ops utility functions
3540-
// Extracts shape as vector of int64_t from vector of Value
3541-
SmallVector<int64_t> getIntShapeFromValues(ArrayRef<Value> vals) {
3542-
SmallVector<int64_t> shape;
3543-
shape.reserve(vals.size());
3544-
for (Value v : vals) {
3545-
int64_t cst_val;
3546-
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
3547-
shape.push_back(cst_val);
3548-
} else {
3549-
shape.push_back(kUnknownSize);
3550-
}
3551-
}
3552-
return shape;
3553-
}
3554-
3555-
// Converts a vector of Value (shape dimensions) into a ValueTensorType
3556-
ValueTensorType getTypeFromShape(ArrayRef<Value> vals, Type inOptionalDType) {
3557-
SmallVector<int64_t> intShape = getIntShapeFromValues(vals);
3558-
return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape),
3559-
inOptionalDType);
3560-
}
3561-
} // namespace
3562-
35633539
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
35643540
// prims.collapse operations.
35653541
//
@@ -3609,18 +3585,9 @@ class DecomposeAtenPixelShuffleOp
36093585

36103586
auto nLeadingDims = inRank - 3;
36113587

3612-
// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
3613-
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
3614-
// folded to a ConstantOp.
3615-
auto getDimSize = [&](uint64_t i) -> Value {
3616-
Value dim =
3617-
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3618-
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3619-
};
3620-
3621-
auto inC = getDimSize(inRank - 3);
3622-
auto inH = getDimSize(inRank - 2);
3623-
auto inW = getDimSize(inRank - 1);
3588+
auto inC = getTensorDimSize(rewriter, inValue, inRank - 3);
3589+
auto inH = getTensorDimSize(rewriter, inValue, inRank - 2);
3590+
auto inW = getTensorDimSize(rewriter, inValue, inRank - 1);
36243591

36253592
auto factor = op.getUpscaleFactor();
36263593

@@ -3678,23 +3645,26 @@ class DecomposeAtenPixelShuffleOp
36783645
auto partiallyExpanded =
36793646
rewriter
36803647
.create<PrimsSplitDimOp>(
3681-
loc, getTypeFromShape(partiallyExpandedShape, inOptionalDType),
3648+
loc,
3649+
getTensorTypeFromShapeValues(partiallyExpandedShape,
3650+
inOptionalDType),
36823651
inValue, dimensionConstants[nLeadingDims], outC)
36833652
.getResult();
36843653

36853654
// Split new dimension factorSquared -> (factor, factor)
36863655
auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3687-
loc, getTypeFromShape(prePermuteShape, inOptionalDType),
3656+
loc, getTensorTypeFromShapeValues(prePermuteShape, inOptionalDType),
36883657
partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor);
36893658

36903659
// Perform the permutation
36913660
auto permuted = rewriter.create<AtenPermuteOp>(
3692-
loc, getTypeFromShape(postPermuteShape, inOptionalDType), fullyExpanded,
3693-
permuteDimsOrder);
3661+
loc, getTensorTypeFromShapeValues(postPermuteShape, inOptionalDType),
3662+
fullyExpanded, permuteDimsOrder);
36943663

36953664
// Collapse final 2 dimension
36963665
auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
3697-
loc, getTypeFromShape(partiallyCollapsedShape, inOptionalDType),
3666+
loc,
3667+
getTensorTypeFromShapeValues(partiallyCollapsedShape, inOptionalDType),
36983668
permuted, dimensionConstants[nLeadingDims + 3],
36993669
dimensionConstants[nLeadingDims + 4]);
37003670

@@ -3709,6 +3679,142 @@ class DecomposeAtenPixelShuffleOp
37093679
};
37103680
} // namespace
37113681

3682+
// Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and
3683+
// prims.collapse operations.
3684+
//
3685+
// We want to do the exact opposite of aten.pixel_shuffle
3686+
//
3687+
// 'r' is referred to as the 'downscale factor' or just 'factor' below.
3688+
//
3689+
// If input is a tensor of shape
3690+
// (*leading_dims, C, H*r, W*r),
3691+
//
3692+
// where leading_dims is of size N, then
3693+
// X = pixel_unshuffle(input, downscale_factor)
3694+
//
3695+
// gets replaced with
3696+
// X = input.split_dim(...) # shape (*leading_dims, C, H, r, W*r)
3697+
// X = X.split_dim(...) # shape (*leading_dims, C, H, r, W, r)
3698+
// X = X.permute(0, ..., N, N+2, N+4, N+1, N+3)
3699+
// # shape (*leading_dims, C, r, r, H, W)
3700+
// X = X.collapse(...) # shape (*leading_dims, C*r*r, H, W)
3701+
//
3702+
namespace {
3703+
class DecomposeAtenPixelUnshuffleOp
3704+
: public OpRewritePattern<AtenPixelUnshuffleOp> {
3705+
public:
3706+
using OpRewritePattern::OpRewritePattern;
3707+
LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op,
3708+
PatternRewriter &rewriter) const override {
3709+
3710+
Location loc = op.getLoc();
3711+
Value inValue = op.getSelf();
3712+
auto inType = cast<BaseTensorType>(inValue.getType());
3713+
auto maybeSizes = inType.getOptionalSizes();
3714+
if (!maybeSizes) {
3715+
return rewriter.notifyMatchFailure(
3716+
op, "Expected input tensor to have known rank.");
3717+
}
3718+
auto inShape = maybeSizes.value();
3719+
auto inRank = inShape.size();
3720+
3721+
// The input tensor must have at least 3 dimensions: (1) the channel
3722+
// dimension which gets bigger by 'factor*factor', (2) the H channel which
3723+
// gets smaller by 'factor' and (3) the W channel which get smaller by
3724+
// 'factor'. The total number of dimensions is 3 + N, where N is the number
3725+
// of leading dimensions, and N >= 0 so the input must have rank at least 3.
3726+
if (inRank < 3)
3727+
return rewriter.notifyMatchFailure(
3728+
op, "Expected input tensor to have rank greater than 2.");
3729+
3730+
const auto inOptionalDType = inType.getOptionalDtype();
3731+
3732+
auto nLeadingDims = inRank - 3;
3733+
3734+
auto inC = getTensorDimSize(rewriter, inValue, inRank - 3);
3735+
auto inH = getTensorDimSize(rewriter, inValue, inRank - 2);
3736+
auto inW = getTensorDimSize(rewriter, inValue, inRank - 1);
3737+
3738+
auto factor = op.getDownscaleFactor();
3739+
3740+
Value factorSquared =
3741+
rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
3742+
3743+
Value outC = rewriter.createOrFold<AtenMulIntOp>(loc, inC, factorSquared);
3744+
3745+
Value outH = rewriter.createOrFold<AtenFloordivIntOp>(loc, inH, factor);
3746+
Value outW = rewriter.createOrFold<AtenFloordivIntOp>(loc, inW, factor);
3747+
3748+
SmallVector<Value> dimensionConstants;
3749+
dimensionConstants.reserve(inRank + 2);
3750+
for (unsigned i = 0; i < inRank + 2; ++i) {
3751+
dimensionConstants.push_back(
3752+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
3753+
}
3754+
3755+
SmallVector<Value> leadingDims;
3756+
leadingDims.reserve(nLeadingDims);
3757+
for (unsigned i = 0; i < nLeadingDims; ++i) {
3758+
Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
3759+
loc, inValue, dimensionConstants[i]);
3760+
leadingDims.push_back(leadingDimSize);
3761+
}
3762+
3763+
SmallVector<Value> prePermuteShape = leadingDims;
3764+
prePermuteShape.append({inC, outH, factor, outW, factor});
3765+
3766+
SmallVector<Value> postPermuteShape = leadingDims;
3767+
postPermuteShape.append({inC, factor, factor, outH, outW});
3768+
3769+
SmallVector<Value> partiallyCollapsedShape = leadingDims;
3770+
partiallyCollapsedShape.append({inC, factorSquared, outH, outW});
3771+
3772+
SmallVector<Value> outShape = leadingDims;
3773+
outShape.append({outC, outH, outW});
3774+
3775+
SmallVector<Value> permutation{dimensionConstants.begin(),
3776+
dimensionConstants.begin() + nLeadingDims};
3777+
SmallVector<uint64_t> permutationTail{0, 2, 4, 1, 3};
3778+
for (uint64_t d : permutationTail) {
3779+
permutation.push_back(dimensionConstants[nLeadingDims + d]);
3780+
}
3781+
3782+
Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
3783+
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
3784+
permutation);
3785+
3786+
SmallVector<Value> heightSplitShape = leadingDims;
3787+
heightSplitShape.append({inC, outH, factor, inW});
3788+
3789+
// Split input channel inH -> (outH, factor)
3790+
auto partiallyExpanded =
3791+
rewriter
3792+
.create<PrimsSplitDimOp>(
3793+
loc,
3794+
getTensorTypeFromShapeValues(heightSplitShape, inOptionalDType),
3795+
inValue, dimensionConstants[nLeadingDims + 1], outH)
3796+
.getResult();
3797+
3798+
// Split new dimension inW -> (outW, factor)
3799+
auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3800+
loc, getTensorTypeFromShapeValues(prePermuteShape, inOptionalDType),
3801+
partiallyExpanded, dimensionConstants[nLeadingDims + 3], outW);
3802+
3803+
// Perform the permutation
3804+
auto permuted = rewriter.create<AtenPermuteOp>(
3805+
loc, getTensorTypeFromShapeValues(postPermuteShape, inOptionalDType),
3806+
fullyExpanded, permuteDimsOrder);
3807+
3808+
// Collapse final 2 dimensions back to original rank
3809+
rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
3810+
op, op.getType(), permuted, dimensionConstants[nLeadingDims],
3811+
dimensionConstants[nLeadingDims + 2]);
3812+
3813+
return success();
3814+
}
3815+
};
3816+
} // namespace
3817+
37123818
// Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
37133819
// prims.collapse operations.
37143820
//
@@ -3763,23 +3869,14 @@ class DecomposeAtenChannelShuffleOp
37633869

37643870
auto numOfSpatialDims = inRank - 2;
37653871

3766-
// Get the size of the dimension 'i'. Note the use of 'createOrFold'
3767-
// instead of 'create': if the dimension size is known, then the
3768-
// AtenSizeIntOp is folded to a ConstantOp.
3769-
auto getDimSize = [&rewriter, &inValue, loc](uint64_t i) -> Value {
3770-
Value dim =
3771-
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3772-
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3773-
};
3774-
37753872
// The channel dimension is always the second dimension. PyTorch errors out
37763873
// if the batch dimension (first dimension) is not present. See comment at
37773874
// the top of this class for details.
3778-
auto inC = getDimSize(1);
3875+
auto inC = getTensorDimSize(rewriter, inValue, 1);
37793876
SmallVector<Value> inSpatialDims;
37803877
inSpatialDims.reserve(numOfSpatialDims);
37813878
for (unsigned i = 2; i < (2 + numOfSpatialDims); ++i) {
3782-
inSpatialDims.push_back(getDimSize(i));
3879+
inSpatialDims.push_back(getTensorDimSize(rewriter, inValue, i));
37833880
}
37843881

37853882
auto groups = op.getGroups();
@@ -3832,14 +3929,14 @@ class DecomposeAtenChannelShuffleOp
38323929
auto expandedTensor =
38333930
rewriter
38343931
.create<PrimsSplitDimOp>(
3835-
loc, getTypeFromShape(splitShape, inOptionalDType), inValue,
3836-
dimC, tempC)
3932+
loc, getTensorTypeFromShapeValues(splitShape, inOptionalDType),
3933+
inValue, dimC, tempC)
38373934
.getResult();
38383935

38393936
// Perform the permutation
38403937
auto permuted = rewriter.create<AtenPermuteOp>(
3841-
loc, getTypeFromShape(permuteShape, inOptionalDType), expandedTensor,
3842-
permuteDimsOrder);
3938+
loc, getTensorTypeFromShapeValues(permuteShape, inOptionalDType),
3939+
expandedTensor, permuteDimsOrder);
38433940

38443941
// Collapse (C, groups) back into a single channel dimension
38453942
rewriter.replaceOpWithNewOp<PrimsCollapseOp>(op, op.getType(), permuted,
@@ -12909,6 +13006,7 @@ class DecomposeComplexOpsPass
1290913006
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
1291013007
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
1291113008
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
13009+
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelUnshuffleOp>(patterns);
1291213010
addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
1291313011
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
1291413012
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
421421
target.addIllegalOp<Aten_LinalgDetOp>();
422422
target.addIllegalOp<AtenLinalgSlogdetOp>();
423423
target.addIllegalOp<AtenPixelShuffleOp>();
424+
target.addIllegalOp<AtenPixelUnshuffleOp>();
424425
target.addIllegalOp<AtenChannelShuffleOp>();
425426
target.addIllegalOp<AtenTOp>();
426427
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,8 @@ bool Torch::isViewLikeOp(Operation *op) {
327327
AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
328328
PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp,
329329
AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp,
330-
AtenChannelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
330+
AtenPixelUnshuffleOp, AtenChannelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(
331+
op);
331332
}
332333

333334
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,

0 commit comments

Comments
 (0)