@@ -3536,30 +3536,6 @@ class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
3536
3536
};
3537
3537
} // namespace
3538
3538
3539
- namespace { // Start of rearrangement ops utility functions
3540
- // Extracts shape as vector of int64_t from vector of Value
3541
- SmallVector<int64_t> getIntShapeFromValues(ArrayRef<Value> vals) {
3542
- SmallVector<int64_t> shape;
3543
- shape.reserve(vals.size());
3544
- for (Value v : vals) {
3545
- int64_t cst_val;
3546
- if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
3547
- shape.push_back(cst_val);
3548
- } else {
3549
- shape.push_back(kUnknownSize);
3550
- }
3551
- }
3552
- return shape;
3553
- }
3554
-
3555
- // Converts a vector of Value (shape dimensions) into a ValueTensorType
3556
- ValueTensorType getTypeFromShape(ArrayRef<Value> vals, Type inOptionalDType) {
3557
- SmallVector<int64_t> intShape = getIntShapeFromValues(vals);
3558
- return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape),
3559
- inOptionalDType);
3560
- }
3561
- } // namespace
3562
-
3563
3539
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
3564
3540
// prims.collapse operations.
3565
3541
//
@@ -3609,18 +3585,9 @@ class DecomposeAtenPixelShuffleOp
3609
3585
3610
3586
auto nLeadingDims = inRank - 3;
3611
3587
3612
- // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
3613
- // of 'create': if the dimension size is known, then the AtenSizeIntOp is
3614
- // folded to a ConstantOp.
3615
- auto getDimSize = [&](uint64_t i) -> Value {
3616
- Value dim =
3617
- rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3618
- return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3619
- };
3620
-
3621
- auto inC = getDimSize(inRank - 3);
3622
- auto inH = getDimSize(inRank - 2);
3623
- auto inW = getDimSize(inRank - 1);
3588
+ auto inC = getTensorDimSize(rewriter, inValue, inRank - 3);
3589
+ auto inH = getTensorDimSize(rewriter, inValue, inRank - 2);
3590
+ auto inW = getTensorDimSize(rewriter, inValue, inRank - 1);
3624
3591
3625
3592
auto factor = op.getUpscaleFactor();
3626
3593
@@ -3678,23 +3645,26 @@ class DecomposeAtenPixelShuffleOp
3678
3645
auto partiallyExpanded =
3679
3646
rewriter
3680
3647
.create<PrimsSplitDimOp>(
3681
- loc, getTypeFromShape(partiallyExpandedShape, inOptionalDType),
3648
+ loc,
3649
+ getTensorTypeFromShapeValues(partiallyExpandedShape,
3650
+ inOptionalDType),
3682
3651
inValue, dimensionConstants[nLeadingDims], outC)
3683
3652
.getResult();
3684
3653
3685
3654
// Split new dimension factorSquared -> (factor, factor)
3686
3655
auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3687
- loc, getTypeFromShape (prePermuteShape, inOptionalDType),
3656
+ loc, getTensorTypeFromShapeValues (prePermuteShape, inOptionalDType),
3688
3657
partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor);
3689
3658
3690
3659
// Perform the permutation
3691
3660
auto permuted = rewriter.create<AtenPermuteOp>(
3692
- loc, getTypeFromShape (postPermuteShape, inOptionalDType), fullyExpanded ,
3693
- permuteDimsOrder);
3661
+ loc, getTensorTypeFromShapeValues (postPermuteShape, inOptionalDType),
3662
+ fullyExpanded, permuteDimsOrder);
3694
3663
3695
3664
// Collapse final 2 dimension
3696
3665
auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
3697
- loc, getTypeFromShape(partiallyCollapsedShape, inOptionalDType),
3666
+ loc,
3667
+ getTensorTypeFromShapeValues(partiallyCollapsedShape, inOptionalDType),
3698
3668
permuted, dimensionConstants[nLeadingDims + 3],
3699
3669
dimensionConstants[nLeadingDims + 4]);
3700
3670
@@ -3709,6 +3679,142 @@ class DecomposeAtenPixelShuffleOp
3709
3679
};
3710
3680
} // namespace
3711
3681
3682
+ // Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and
3683
+ // prims.collapse operations.
3684
+ //
3685
+ // We want to do the exact opposite of aten.pixel_shuffle
3686
+ //
3687
+ // 'r' is referred to as the 'downscale factor' or just 'factor' below.
3688
+ //
3689
+ // If input is a tensor of shape
3690
+ // (*leading_dims, C, H*r, W*r),
3691
+ //
3692
+ // where leading_dims is of size N, then
3693
+ // X = pixel_unshuffle(input, downscale_factor)
3694
+ //
3695
+ // gets replaced with
3696
+ // X = input.split_dim(...) # shape (*leading_dims, C, H, r, W*r)
3697
+ // X = X.split_dim(...) # shape (*leading_dims, C, H, r, W, r)
3698
+ // X = X.permute(0, ..., N, N+2, N+4, N+1, N+3)
3699
+ // # shape (*leading_dims, C, r, r, H, W)
3700
+ // X = X.collapse(...) # shape (*leading_dims, C*r*r, H, W)
3701
+ //
3702
+ namespace {
3703
+ class DecomposeAtenPixelUnshuffleOp
3704
+ : public OpRewritePattern<AtenPixelUnshuffleOp> {
3705
+ public:
3706
+ using OpRewritePattern::OpRewritePattern;
3707
+ LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op,
3708
+ PatternRewriter &rewriter) const override {
3709
+
3710
+ Location loc = op.getLoc();
3711
+ Value inValue = op.getSelf();
3712
+ auto inType = cast<BaseTensorType>(inValue.getType());
3713
+ auto maybeSizes = inType.getOptionalSizes();
3714
+ if (!maybeSizes) {
3715
+ return rewriter.notifyMatchFailure(
3716
+ op, "Expected input tensor to have known rank.");
3717
+ }
3718
+ auto inShape = maybeSizes.value();
3719
+ auto inRank = inShape.size();
3720
+
3721
+ // The input tensor must have at least 3 dimensions: (1) the channel
3722
+ // dimension which gets bigger by 'factor*factor', (2) the H channel which
3723
+ // gets smaller by 'factor' and (3) the W channel which get smaller by
3724
+ // 'factor'. The total number of dimensions is 3 + N, where N is the number
3725
+ // of leading dimensions, and N >= 0 so the input must have rank at least 3.
3726
+ if (inRank < 3)
3727
+ return rewriter.notifyMatchFailure(
3728
+ op, "Expected input tensor to have rank greater than 2.");
3729
+
3730
+ const auto inOptionalDType = inType.getOptionalDtype();
3731
+
3732
+ auto nLeadingDims = inRank - 3;
3733
+
3734
+ auto inC = getTensorDimSize(rewriter, inValue, inRank - 3);
3735
+ auto inH = getTensorDimSize(rewriter, inValue, inRank - 2);
3736
+ auto inW = getTensorDimSize(rewriter, inValue, inRank - 1);
3737
+
3738
+ auto factor = op.getDownscaleFactor();
3739
+
3740
+ Value factorSquared =
3741
+ rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
3742
+
3743
+ Value outC = rewriter.createOrFold<AtenMulIntOp>(loc, inC, factorSquared);
3744
+
3745
+ Value outH = rewriter.createOrFold<AtenFloordivIntOp>(loc, inH, factor);
3746
+ Value outW = rewriter.createOrFold<AtenFloordivIntOp>(loc, inW, factor);
3747
+
3748
+ SmallVector<Value> dimensionConstants;
3749
+ dimensionConstants.reserve(inRank + 2);
3750
+ for (unsigned i = 0; i < inRank + 2; ++i) {
3751
+ dimensionConstants.push_back(
3752
+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
3753
+ }
3754
+
3755
+ SmallVector<Value> leadingDims;
3756
+ leadingDims.reserve(nLeadingDims);
3757
+ for (unsigned i = 0; i < nLeadingDims; ++i) {
3758
+ Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
3759
+ loc, inValue, dimensionConstants[i]);
3760
+ leadingDims.push_back(leadingDimSize);
3761
+ }
3762
+
3763
+ SmallVector<Value> prePermuteShape = leadingDims;
3764
+ prePermuteShape.append({inC, outH, factor, outW, factor});
3765
+
3766
+ SmallVector<Value> postPermuteShape = leadingDims;
3767
+ postPermuteShape.append({inC, factor, factor, outH, outW});
3768
+
3769
+ SmallVector<Value> partiallyCollapsedShape = leadingDims;
3770
+ partiallyCollapsedShape.append({inC, factorSquared, outH, outW});
3771
+
3772
+ SmallVector<Value> outShape = leadingDims;
3773
+ outShape.append({outC, outH, outW});
3774
+
3775
+ SmallVector<Value> permutation{dimensionConstants.begin(),
3776
+ dimensionConstants.begin() + nLeadingDims};
3777
+ SmallVector<uint64_t> permutationTail{0, 2, 4, 1, 3};
3778
+ for (uint64_t d : permutationTail) {
3779
+ permutation.push_back(dimensionConstants[nLeadingDims + d]);
3780
+ }
3781
+
3782
+ Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
3783
+ loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
3784
+ permutation);
3785
+
3786
+ SmallVector<Value> heightSplitShape = leadingDims;
3787
+ heightSplitShape.append({inC, outH, factor, inW});
3788
+
3789
+ // Split input channel inH -> (outH, factor)
3790
+ auto partiallyExpanded =
3791
+ rewriter
3792
+ .create<PrimsSplitDimOp>(
3793
+ loc,
3794
+ getTensorTypeFromShapeValues(heightSplitShape, inOptionalDType),
3795
+ inValue, dimensionConstants[nLeadingDims + 1], outH)
3796
+ .getResult();
3797
+
3798
+ // Split new dimension inW -> (outW, factor)
3799
+ auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3800
+ loc, getTensorTypeFromShapeValues(prePermuteShape, inOptionalDType),
3801
+ partiallyExpanded, dimensionConstants[nLeadingDims + 3], outW);
3802
+
3803
+ // Perform the permutation
3804
+ auto permuted = rewriter.create<AtenPermuteOp>(
3805
+ loc, getTensorTypeFromShapeValues(postPermuteShape, inOptionalDType),
3806
+ fullyExpanded, permuteDimsOrder);
3807
+
3808
+ // Collapse final 2 dimensions back to original rank
3809
+ rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
3810
+ op, op.getType(), permuted, dimensionConstants[nLeadingDims],
3811
+ dimensionConstants[nLeadingDims + 2]);
3812
+
3813
+ return success();
3814
+ }
3815
+ };
3816
+ } // namespace
3817
+
3712
3818
// Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
3713
3819
// prims.collapse operations.
3714
3820
//
@@ -3763,23 +3869,14 @@ class DecomposeAtenChannelShuffleOp
3763
3869
3764
3870
auto numOfSpatialDims = inRank - 2;
3765
3871
3766
- // Get the size of the dimension 'i'. Note the use of 'createOrFold'
3767
- // instead of 'create': if the dimension size is known, then the
3768
- // AtenSizeIntOp is folded to a ConstantOp.
3769
- auto getDimSize = [&rewriter, &inValue, loc](uint64_t i) -> Value {
3770
- Value dim =
3771
- rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3772
- return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3773
- };
3774
-
3775
3872
// The channel dimension is always the second dimension. PyTorch errors out
3776
3873
// if the batch dimension (first dimension) is not present. See comment at
3777
3874
// the top of this class for details.
3778
- auto inC = getDimSize( 1);
3875
+ auto inC = getTensorDimSize(rewriter, inValue, 1);
3779
3876
SmallVector<Value> inSpatialDims;
3780
3877
inSpatialDims.reserve(numOfSpatialDims);
3781
3878
for (unsigned i = 2; i < (2 + numOfSpatialDims); ++i) {
3782
- inSpatialDims.push_back(getDimSize( i));
3879
+ inSpatialDims.push_back(getTensorDimSize(rewriter, inValue, i));
3783
3880
}
3784
3881
3785
3882
auto groups = op.getGroups();
@@ -3832,14 +3929,14 @@ class DecomposeAtenChannelShuffleOp
3832
3929
auto expandedTensor =
3833
3930
rewriter
3834
3931
.create<PrimsSplitDimOp>(
3835
- loc, getTypeFromShape (splitShape, inOptionalDType), inValue ,
3836
- dimC, tempC)
3932
+ loc, getTensorTypeFromShapeValues (splitShape, inOptionalDType),
3933
+ inValue, dimC, tempC)
3837
3934
.getResult();
3838
3935
3839
3936
// Perform the permutation
3840
3937
auto permuted = rewriter.create<AtenPermuteOp>(
3841
- loc, getTypeFromShape (permuteShape, inOptionalDType), expandedTensor ,
3842
- permuteDimsOrder);
3938
+ loc, getTensorTypeFromShapeValues (permuteShape, inOptionalDType),
3939
+ expandedTensor, permuteDimsOrder);
3843
3940
3844
3941
// Collapse (C, groups) back into a single channel dimension
3845
3942
rewriter.replaceOpWithNewOp<PrimsCollapseOp>(op, op.getType(), permuted,
@@ -12909,6 +13006,7 @@ class DecomposeComplexOpsPass
12909
13006
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
12910
13007
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
12911
13008
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
13009
+ addPatternIfTargetOpIsIllegal<DecomposeAtenPixelUnshuffleOp>(patterns);
12912
13010
addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
12913
13011
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
12914
13012
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
0 commit comments