@@ -46,6 +46,62 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
4646};
4747} // namespace
4848
49+ namespace {
50+ class InferTensorOp : public OpRewritePattern <AtenTensorOp> {
51+ public:
52+ using OpRewritePattern::OpRewritePattern;
53+ LogicalResult matchAndRewrite (AtenTensorOp op,
54+ PatternRewriter &rewriter) const override {
55+ auto context = op.getContext ();
56+ auto loc = op.getLoc ();
57+ auto result = op.getResult ();
58+ auto resultType = cast<BaseTensorType>(result.getType ());
59+ if (resultType.hasSizes () && resultType.hasDtype ()) {
60+ return rewriter.notifyMatchFailure (
61+ op, " The result of aten.tensor is already a BaseTensorType." );
62+ }
63+
64+ auto inputList = op.getOperand (0 );
65+ auto listConstruct = inputList.getDefiningOp <PrimListConstructOp>();
66+ if (!listConstruct) {
67+ return rewriter.notifyMatchFailure (
68+ op, " The operand 0 of aten.tensor is not PrimListConstructOp." );
69+ }
70+
71+ // Currently only support the 1d input list.
72+ SmallVector<int64_t > sizes;
73+ sizes.push_back (listConstruct->getOperands ().size ());
74+ FailureOr<Type> torchType;
75+ auto eleType = listConstruct->getOperands ()[0 ].getType ();
76+ if (isa<Torch::IntType>(eleType)) {
77+ torchType = getTypeForScalarType (op->getContext (),
78+ torch_upstream::ScalarType::Long);
79+ } else if (isa<Torch::FloatType>(eleType)) {
80+ torchType = getTypeForScalarType (op->getContext (),
81+ torch_upstream::ScalarType::Float);
82+ } else {
83+ return rewriter.notifyMatchFailure (
84+ op, " Currently only support Int and Float Type." );
85+ }
86+ auto newResultType = ValueTensorType::get (context, sizes, *torchType);
87+
88+ Value originalTypedValue;
89+ for (OpOperand &use : llvm::make_early_inc_range (result.getUses ())) {
90+ if (!originalTypedValue) {
91+ rewriter.setInsertionPointAfter (op);
92+ originalTypedValue =
93+ rewriter.create <TensorStaticInfoCastOp>(loc, resultType, result);
94+ }
95+ use.set (originalTypedValue);
96+ }
97+
98+ result.setType (newResultType);
99+
100+ return success ();
101+ }
102+ };
103+ } // namespace
104+
49105static LogicalResult refineShapeCalculateResult (ShapeCalculateOp op,
50106 int resultNum,
51107 PatternRewriter &rewriter) {
@@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass
135191 populateFoldPrimUncheckedCastOpPattern (patterns, context);
136192 patterns.insert <DecomposeAtenSizeOp>(context);
137193 patterns.insert <RefineShapeCalculateOp>(context);
194+ patterns.insert <InferTensorOp>(context);
138195
139196 PrimIfOp::getCanonicalizationPatterns (patterns, context);
140197 Aten__Getitem__TOp::getCanonicalizationPatterns (patterns, context);
0 commit comments