Commit 2665ed3

adds a few common patterns to scalarize shapes pass (#3779)
This patch adds two things:

1. Support for folding scalar patterns like [1] ---squeeze---> [] ---unsqueeze---> [1].
2. A canonicalizer for aten.view that applies when we can statically or dynamically (through the scalarized view shapes) infer that the view is a flatten or unflatten of the last dim.

I'm not sure this is the right place to add such a view canonicalizer. Catastrophically, there is a decomposition from flatten and unflatten into aten.view. Until that decomposition gets deleted (and it definitely should be deleted), this felt like an appropriate temporary home: we run scalarize-shapes after lowering to the backend contract (i.e., after decomposing), and scalarize-shapes needs to be able to infer dynamic dims coming from size.int ops.
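For concreteness, a minimal torch-dialect sketch of the first pattern (illustrative only, not a test added by this commit): a rank-1 value squeezed to rank 0 and immediately unsqueezed back to rank 1 now folds to the original value.

// Hypothetical example of the squeeze/unsqueeze fold handled by
// FoldAtenUnsqueezePattern.
func.func @fold_squeeze_then_unsqueeze(%arg0: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> {
  %int0 = torch.constant.int 0
  // [1] ---squeeze---> []
  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  // [] ---unsqueeze---> [1]; after the fold this result is replaced by %arg0
  // and the squeeze/unsqueeze pair becomes dead.
  %1 = torch.aten.unsqueeze %0, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
  return %1 : !torch.vtensor<[1],si64>
}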
1 parent: d0041dc

2 files changed: +234 additions, -12 deletions

lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp

Lines changed: 146 additions & 12 deletions
@@ -530,11 +530,139 @@ class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
           none, none, none, none);
       return success();
     }
+    auto squeezeOp = op.getSelf().getDefiningOp<AtenSqueezeDimOp>();
+    if (squeezeOp && resultTy.getSizes().size() == 1) {
+      rewriter.replaceOp(op, squeezeOp.getSelf());
+      return success();
+    }
 
     return failure();
   }
 };
 } // namespace
+
+namespace {
+// This is a specific pattern for converting views like [?,...,?,lastDim] ->
+// [?,...,?,factor0,factor1] to unflatten, and views like
+// [?,...,?,factor0,factor1] -> [?,...,?,lastDim] to flatten, whenever it is
+// possible to infer that all but last shared dim match
+// TODO: move this to an actual canonicalizer for view after deleting the
+// conflicting decompositions for flatten/unflatten -> view.
+class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
+public:
+  using OpRewritePattern<AtenViewOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenViewOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> viewSizes;
+    if (failed(getListOperands(op.getSize(), viewSizes)))
+      return rewriter.notifyMatchFailure(
+          op, "view size must be from a list construct");
+    auto selfTy = dyn_cast<Torch::ValueTensorType>(op.getSelf().getType());
+    if (!selfTy || !selfTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "missing input type or sizes");
+    auto resultTy = dyn_cast<Torch::ValueTensorType>(op.getType());
+    if (!resultTy || !resultTy.hasSizes() ||
+        resultTy.getSizes().size() != viewSizes.size())
+      return rewriter.notifyMatchFailure(op, "missing result type or sizes");
+    int64_t inRank = selfTy.getSizes().size();
+    int64_t outRank = resultTy.getSizes().size();
+
+    SmallVector<int64_t> sizes(selfTy.getSizes());
+    int64_t endMatchingDim = -1;
+    // input sizes vs. provided view sizes comparison loop
+    for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
+      int64_t providedSize;
+      bool providedStatic =
+          matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
+      // if sizes[i] is static, it must match a constant in viewSizes[i]
+      if (sizes[i] != Torch::kUnknownSize) {
+        if (!providedStatic)
+          return rewriter.notifyMatchFailure(
+              op, "unsupported: found static input dim, but unable to match "
+                  "provided view size on a constant. See position : " +
+                      std::to_string(i));
+        if (providedSize != sizes[i]) {
+          endMatchingDim = i;
+          break;
+        }
+        continue;
+      }
+      // the remaining assumes sizes[i] is dynamic
+      // if provided dim is static, we can't verify it is a flatten/unflatten
+      // unless -1
+      if (i == outRank - 1 && providedStatic && providedSize == -1) {
+        endMatchingDim = i;
+        break;
+      }
+      if (providedStatic)
+        return rewriter.notifyMatchFailure(
+            op, "unexpected static view dim corresponding to dynamic input dim "
+                "at position : " +
+                    std::to_string(i));
+      auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
+      // if we don't have a size int op on self, fail
+      if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
+        return rewriter.notifyMatchFailure(
+            op, "expected dynamic view dim to come from a corresponding "
+                "size.int op. See position : " +
+                    std::to_string(i));
+      int64_t dim;
+      // if the dim of the size int op doesn't match, fail
+      if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
+          dim != i)
+        return rewriter.notifyMatchFailure(
+            op,
+            "size int op dim cannot be matched to current dim at position : " +
+                std::to_string(i));
+      // passing the previous checks means viewSizes[i] = aten.size.int(self,
+      // i), so continue
+    }
+    // if all dims match and the ranks are equal, fold
+    if (endMatchingDim == -1 && inRank == outRank) {
+      rewriter.replaceOp(op, op.getSelf());
+      return success();
+    }
+    if (endMatchingDim > -1 && inRank > outRank) {
+      // only support flattening last dim
+      if (endMatchingDim != outRank - 1)
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: output has more than back dim mismatching");
+      // flatten
+      Value start =
+          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
+      Value end =
+          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
+      rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
+          op, resultTy, op.getSelf(), start, end);
+      return success();
+    }
+    if (endMatchingDim > -1 && inRank < outRank) {
+      // only support unflattening last dim
+      if (endMatchingDim != inRank - 1)
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: input has more than back dim mismatching");
+      // unflatten
+      Value dim =
+          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
+      Value primList = rewriter.create<Torch::PrimListConstructOp>(
+          op.getLoc(), op.getSize().getType(),
+          ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
+      rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
+          op, resultTy, op.getSelf(), dim, primList);
+      return success();
+    }
+    // examples that might reach this:
+    // input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
+    // input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
+    // input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
+    return rewriter.notifyMatchFailure(
+        op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
+                ", inRank=" + std::to_string(inRank) +
+                ", outRank=" + std::to_string(outRank));
+  }
+};
+} // namespace
+
 namespace {
 template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
 public:
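The trailing comment in the pattern lists shapes that fall through to the final notifyMatchFailure. A hypothetical torch-dialect example of one such case (not included in this commit), where the mismatch is not a last-dim flatten or unflatten:

// Hypothetical IR that CanonicalizeAtenViewPattern rejects: the first static
// dim already mismatches (5 vs 10) and the ranks are equal, so the view is
// left untouched and the rewrite reports "unhandled case".
func.func @view_static_mismatch_not_canonicalized(%arg0: !torch.vtensor<[10,5],f32>) -> !torch.vtensor<[5,10],f32> {
  %int5 = torch.constant.int 5
  %int10 = torch.constant.int 10
  %0 = torch.prim.ListConstruct %int5, %int10 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[10,5],f32>, !torch.list<int> -> !torch.vtensor<[5,10],f32>
  return %1 : !torch.vtensor<[5,10],f32>
}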
@@ -561,18 +689,24 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns
-        .insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
-                PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
-                PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
-                FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
-                FoldAtenWhereSelf, RemoveUnusedPattern<Torch::AtenSizeIntOp>,
-                RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
-                RemoveUnusedPattern<Torch::AtenTensorOp>,
-                RemoveUnusedPattern<Torch::ConstantBoolOp>,
-                RemoveUnusedPattern<Torch::ConstantIntOp>,
-                RemoveUnusedPattern<Torch::ConstantNoneOp>,
-                RemoveUnusedPattern<Torch::PrimListConstructOp>>(context);
+    patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
+                    PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
+                    PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
+                    FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
+                    FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
+                    RemoveUnusedPattern<Torch::AtenIntBoolOp>,
+                    RemoveUnusedPattern<Torch::AtenEqIntOp>,
+                    RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
+                    RemoveUnusedPattern<Torch::AtenFullOp>,
+                    RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
+                    RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
+                    RemoveUnusedPattern<Torch::AtenSizeIntOp>,
+                    RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
+                    RemoveUnusedPattern<Torch::AtenTensorOp>,
+                    RemoveUnusedPattern<Torch::ConstantBoolOp>,
+                    RemoveUnusedPattern<Torch::ConstantIntOp>,
+                    RemoveUnusedPattern<Torch::ConstantNoneOp>,
+                    RemoveUnusedPattern<Torch::PrimListConstructOp>>(context);
 
     context->getLoadedDialect<mlir::arith::ArithDialect>()
         ->getCanonicalizationPatterns(patterns);
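The expanded RemoveUnusedPattern list lets the pass clean up ops orphaned by the new folds. A rough, hypothetical sketch of such leftovers, assuming RemoveUnusedPattern simply erases a listed op once its results have no uses:

// Hypothetical leftovers after FoldAtenUnsqueezePattern has rewired the
// unsqueeze result to %arg0: %0 has no uses, and once it is erased %int0 is
// unused too, so RemoveUnusedPattern<Torch::AtenSqueezeDimOp> and
// RemoveUnusedPattern<Torch::ConstantIntOp> can remove both.
func.func @dead_ops_after_fold(%arg0: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> {
  %int0 = torch.constant.int 0
  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  return %arg0 : !torch.vtensor<[1],si64>
}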

test/Dialect/Torch/scalarize-shapes.mlir

Lines changed: 88 additions & 0 deletions
@@ -72,3 +72,91 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc
   %slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
   return %slice : !torch.vtensor<[2],si32>
 }
+
+
+// -----
+
+// CHECK-LABEL: @view_as_flatten_static
+func.func @view_as_flatten_static(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,1024],f32> {
+  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
+  // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
+  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,1024],f32>
+  %int1024 = torch.constant.int 1024
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
+  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
+  %2 = torch.prim.ListConstruct %0, %1, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,16,64],f32>, !torch.list<int> -> !torch.vtensor<[?,?,1024],f32>
+  return %3 : !torch.vtensor<[?,?,1024],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @view_as_unflatten_static
+func.func @view_as_unflatten_static(%arg0: !torch.vtensor<[?,?,1024],f32>) -> !torch.vtensor<[?,?,16,64],f32> {
+  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[CST16:.*]] = torch.constant.int 16
+  // CHECK-DAG: %[[CST64:.*]] = torch.constant.int 64
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CST16]], %[[CST64]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FLAT:.*]] = torch.aten.unflatten.int %arg0, %[[TWO]], %[[LIST]] : !torch.vtensor<[?,?,1024],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
+  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,16,64],f32>
+  %int16 = torch.constant.int 16
+  %int64 = torch.constant.int 64
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
+  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
+  %2 = torch.prim.ListConstruct %0, %1, %int16, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,1024],f32>, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
+  return %3 : !torch.vtensor<[?,?,16,64],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @view_as_flatten_dynamic
+func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
+  // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32>
+  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?],f32>
+  %int-1 = torch.constant.int -1
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
+  return %3 : !torch.vtensor<[?,?,?],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @unsqueeze_squeeze_combo
+func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int {
+  // CHECK: %int0 = torch.constant.int 0
+  // CHECK: %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
+  // CHECK: return %0 : !torch.int
+  %0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %1 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %2 = torch.vtensor.literal(dense<1024> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
+  %4 = torch.aten.index_select %3, %int0, %1 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
+  %6 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
+  %7 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
+  %9 = torch.aten.unsqueeze %5, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
+  %10 = torch.aten.unsqueeze %8, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
+  %11 = torch.prim.ListConstruct %9, %10, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
+  %12 = torch.aten.cat %11, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
+  %13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
+  return %14 : !torch.int
+}
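A related negative case not covered by the new tests, sketched here as a hypothetical example (not part of this commit): the dynamic view dim is a size.int of a different tensor, so the canonicalizer refuses to match and the aten.view is left as-is.

// Hypothetical negative example: %0 is a size.int of %arg1, not of the view's
// input %arg0, so CanonicalizeAtenViewPattern bails out with "expected dynamic
// view dim to come from a corresponding size.int op" and no flatten is emitted.
func.func @view_dim_from_other_tensor_not_canonicalized(%arg0: !torch.vtensor<[?,16,64],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,1024],f32> {
  %int1024 = torch.constant.int 1024
  %int0 = torch.constant.int 0
  %0 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
  %1 = torch.prim.ListConstruct %0, %int1024 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.aten.view %arg0, %1 : !torch.vtensor<[?,16,64],f32>, !torch.list<int> -> !torch.vtensor<[?,1024],f32>
  return %2 : !torch.vtensor<[?,1024],f32>
}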
