
Commit b176939

[Torch] support 1d aten tensor shape and dtype infer (#3776)
1 parent: ab62f35

2 files changed: 81 additions & 0 deletions


lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp

Lines changed: 57 additions & 0 deletions
```diff
@@ -46,6 +46,62 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 };
 } // namespace
 
+namespace {
+class InferTensorOp : public OpRewritePattern<AtenTensorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    auto context = op.getContext();
+    auto loc = op.getLoc();
+    auto result = op.getResult();
+    auto resultType = cast<BaseTensorType>(result.getType());
+    if (resultType.hasSizes() && resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "The result of aten.tensor is already a BaseTensorType.");
+    }
+
+    auto inputList = op.getOperand(0);
+    auto listConstruct = inputList.getDefiningOp<PrimListConstructOp>();
+    if (!listConstruct) {
+      return rewriter.notifyMatchFailure(
+          op, "The operand 0 of aten.tensor is not PrimListConstructOp.");
+    }
+
+    // Currently only support the 1d input list.
+    SmallVector<int64_t> sizes;
+    sizes.push_back(listConstruct->getOperands().size());
+    FailureOr<Type> torchType;
+    auto eleType = listConstruct->getOperands()[0].getType();
+    if (isa<Torch::IntType>(eleType)) {
+      torchType = getTypeForScalarType(op->getContext(),
+                                       torch_upstream::ScalarType::Long);
+    } else if (isa<Torch::FloatType>(eleType)) {
+      torchType = getTypeForScalarType(op->getContext(),
+                                       torch_upstream::ScalarType::Float);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Currently only support Int and Float Type.");
+    }
+    auto newResultType = ValueTensorType::get(context, sizes, *torchType);
+
+    Value originalTypedValue;
+    for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
+      if (!originalTypedValue) {
+        rewriter.setInsertionPointAfter(op);
+        originalTypedValue =
+            rewriter.create<TensorStaticInfoCastOp>(loc, resultType, result);
+      }
+      use.set(originalTypedValue);
+    }
+
+    result.setType(newResultType);
+
+    return success();
+  }
+};
+} // namespace
+
 static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
                                                 int resultNum,
                                                 PatternRewriter &rewriter) {
```
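The `ScalarType::Long` / `ScalarType::Float` choice above mirrors eager PyTorch's defaults for `torch.tensor` over a 1-d Python list: integer elements give an int64 tensor, floating-point elements a float32 tensor, each of shape `[len(list)]`. A quick eager-mode check of that convention (plain PyTorch, not part of this commit):

```python
import torch

# Integer elements default to int64 ("Long"); the shape is [len(list)].
t_int = torch.tensor([3, 4, 5])
assert t_int.dtype == torch.int64 and t_int.shape == torch.Size([3])

# Floating-point elements default to float32 ("Float").
t_float = torch.tensor([3.0])
assert t_float.dtype == torch.float32 and t_float.shape == torch.Size([1])
```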
```diff
@@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass
     populateFoldPrimUncheckedCastOpPattern(patterns, context);
     patterns.insert<DecomposeAtenSizeOp>(context);
     patterns.insert<RefineShapeCalculateOp>(context);
+    patterns.insert<InferTensorOp>(context);
 
     PrimIfOp::getCanonicalizationPatterns(patterns, context);
     Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
```
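One way to see the new inference end to end is to lower a module through the pt1 TorchScript path, which runs SimplifyShapeCalculations as part of shape/dtype refinement. A sketch, assuming the `torch_mlir.torchscript.compile` entry point (this import path has moved between torch-mlir versions, so treat the names as illustrative):

```python
import torch
from torch_mlir import torchscript  # assumption: pt1 TorchScript entry point

class Shape0ToTensor(torch.nn.Module):
    def forward(self, x):
        # Under TorchScript, x.shape[0] stays a !torch.int, so an aten.tensor
        # op fed by a prim.ListConstruct reaches the new InferTensorOp pattern.
        return torch.tensor([x.shape[0]])

# Lowering to the "torch" output type runs the refinement pipelines; with this
# patch the result should come out typed !torch.vtensor<[1],si64>.
module = torchscript.compile(Shape0ToTensor(), torch.ones(2, 4, 6),
                             output_type="torch")
print(module)
```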

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -5621,6 +5621,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class TensorAlloc1dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4, 6], torch.int, True),
+        ]
+    )
+    def forward(self, x):
+        res = torch.tensor([x.shape[0]])
+        return res
+
+
+@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule())
+def TensorAlloc1dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 6))
+
+
+# ==============================================================================
+
+
 class ScalarTensorFloat32Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
```
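The new test exercises only the `Torch::IntType` branch, since `x.shape[0]` is a `!torch.int`. A hypothetical companion test for the float branch, written in the suite's usual style but not part of this commit, would look like:

```python
class TensorAlloc1dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 4, 6], torch.float32, True),
        ]
    )
    def forward(self, x):
        # float(...) makes the list element a !torch.float, so InferTensorOp
        # should pick ScalarType::Float, i.e. !torch.vtensor<[1],f32>.
        return torch.tensor([float(x.shape[0])])


@register_test_case(module_factory=lambda: TensorAlloc1dFloatModule())
def TensorAlloc1dFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4, 6))
```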
