
Commit 3cb46ce

Added aten::t() Op
1 parent: 5eed562

7 files changed (+125, −2 lines)

e2e_testing/torchscript/basic.py

Lines changed: 53 additions & 0 deletions
@@ -1179,3 +1179,56 @@ def forward(self, a):
 @register_test_case(module_factory=lambda: BoolTensorReturnMixedModule())
 def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
     module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))
+
+# ==============================================================================
+class TModuleRank2(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs):
+        return torch.t(lhs)
+
+
+@register_test_case(module_factory=lambda: TModuleRank2())
+def TModuleRank2_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+class TModuleRank1(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, lhs):
+        return torch.t(lhs)
+
+
+@register_test_case(module_factory=lambda: TModuleRank1())
+def TModuleRank1_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3))
+
+class TModuleRank0(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.float32, True),
+    ])
+    def forward(self, lhs):
+        return torch.t(lhs)
+
+
+@register_test_case(module_factory=lambda: TModuleRank0())
+def TModuleRank0_basic(module, tu: TestUtils):
+    module.forward(torch.tensor(7, dtype=torch.float32))
+
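Note: the three test modules above cover the whole domain of aten::t(): a real transpose at rank 2, and the identity at ranks 0 and 1. A quick stand-alone check of the behavior the tests pin down, using plain PyTorch outside the e2e harness:

import torch

# Rank 2: t() swaps the two dimensions, same as transpose(0, 1).
a = torch.rand(3, 4)
assert torch.t(a).shape == (4, 3)
assert torch.equal(torch.t(a), a.transpose(0, 1))

# Ranks 0 and 1: t() is the identity.
v = torch.rand(3)
s = torch.tensor(7, dtype=torch.float32)
assert torch.equal(torch.t(v), v)
assert torch.equal(torch.t(s), s)

# Rank >= 3 is rejected by PyTorch itself, which is why no such test exists.
# torch.t(torch.rand(2, 2, 2))  # raises RuntimeError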

e2e_testing/torchscript/xfail_sets.py

Lines changed: 2 additions & 0 deletions
@@ -41,4 +41,6 @@
     "SqueezeModule_static",
     "SqueezeModule_noUnitDim",
     "SqueezeModule_allUnitDim",
+    "TModuleRank1_basic",
+    "TModuleRank0_basic",
 }

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 13 additions & 0 deletions
@@ -2603,6 +2603,19 @@ def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
   let assemblyFormat = "$input `,` $p `,` $train attr-dict `:` type($input) `,` type($p) `,` type($train) `->` type($result)";
 }
 
+def Torch_AtenTOp : Torch_Op<"aten.t", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::t : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+}
+
 def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
   AllowsTypeRefinement,
   HasValueSemantics
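Note: given the declared assemblyFormat, a use of this op would round-trip in textual IR roughly as %0 = torch.aten.t %arg0 : !torch.tensor<[3,4],f32> -> !torch.tensor<[4,3],f32> (an illustrative example, not taken from this commit; the exact tensor types depend on the pipeline stage). The op carries AllowsTypeRefinement but not HasValueSemantics, since aten::t returns a view of its input — see the MaximizeValueSemantics change below.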

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 32 additions & 0 deletions
@@ -379,6 +379,36 @@ class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.self();
+    int lhsRank = getTensorRank(lhs);
+    auto loc = op.getLoc();
+
+    if (lhsRank > 2 || lhsRank < 0) {
+      std::string errorMessage =
+          "t() expects a tensor with <=2 dimensions, but self is " +
+          std::to_string(lhsRank) + "D";
+      return rewriter.notifyMatchFailure(op, errorMessage.c_str());
+    } else if (lhsRank < 2)
+      rewriter.replaceOp(op, lhs);
+    else {
+      Value zero =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+      Value one =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+      rewriter.replaceOpWithNewOp<AtenTransposeIntOp>(op, op.getType(), lhs,
+                                                      zero, one);
+    }
+    return success();
+  }
+};
+} // namespace
+
 // Decompose torch.expand into torch.broadcast_to op.
 namespace {
 class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
@@ -565,6 +595,8 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeAtenSelectIntOp>(context);
     target.addIllegalOp<AtenSelectIntOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
+    target.addIllegalOp<AtenTOp>();
+    patterns.add<DecomposeAtenTOp>(context);
     patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
     target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
     target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
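Note: the decomposition mirrors PyTorch's own definition of t(): fail above rank 2 (or when the rank is unknown, reported as a negative value by getTensorRank), forward the input unchanged below rank 2, and emit aten.transpose.int with dims 0 and 1 at rank 2. A minimal Python sketch of the same rank dispatch (decompose_t is a hypothetical stand-in for the C++ pattern, not torch-mlir API):

import torch

def decompose_t(x: torch.Tensor) -> torch.Tensor:
    rank = x.dim()  # the C++ pattern additionally bails out on unknown (negative) rank
    if rank > 2:
        raise RuntimeError(
            f"t() expects a tensor with <=2 dimensions, but self is {rank}D")
    if rank < 2:
        return x  # identity for rank 0 and rank 1
    return torch.transpose(x, 0, 1)  # what the emitted AtenTransposeIntOp computes

for x in (torch.tensor(1.0), torch.rand(3), torch.rand(2, 3)):
    assert torch.equal(decompose_t(x), torch.t(x))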

lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp

Lines changed: 2 additions & 2 deletions
@@ -93,8 +93,8 @@ class RewriteViewLikeSubgraph
             AtenFlattenUsingIntsOp, AtenTransposeIntOp,
             TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
             AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
-            AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp>(
-                op)) {
+            AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
+            AtenTOp>(op)) {
       // AtenContiguousOp might return a view, so this is conservatively
       // correct. We could potentially be more precise and identify the cases
       // that it does not return a view and treat those as having value
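Note: treating AtenTOp as view-like matches PyTorch semantics: like transpose, t() returns a view that aliases the input's storage rather than a fresh tensor, so writes through the result are visible in the operand. For example:

import torch

a = torch.rand(3, 4)
b = torch.t(a)           # a view; no copy is made
b[0, 0] = 42.0
assert a[0, 0] == 42.0   # the write through the view shows up in `a`
assert b._base is a      # `_base` (internal attribute) points at the aliased tensor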

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 22 additions & 0 deletions
@@ -374,6 +374,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitReshapeLikeOp(resize, operands);
     } else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
       return visitAtenTransposeIntOp(transposeInt, operands);
+    } else if (auto t = dyn_cast<AtenTOp>(op)) {
+      return visitAtenTOp(t, operands);
     } else if (auto permute = dyn_cast<AtenPermuteOp>(op)) {
       return visitAtenPermuteOp(permute, operands);
     } else if (auto tensorFloat = dyn_cast<AtenTensorFloatOp>(op)) {
@@ -550,6 +552,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   visitAtenTransposeIntOp(AtenTransposeIntOp op,
                           ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
+  visitAtenTOp(AtenTOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
   visitAtenPermuteOp(AtenPermuteOp op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult visitNumToTensorOp(PrimNumToTensorScalarOp op);
@@ -1242,6 +1246,24 @@ ChangeResult TypeAnalyzer::visitAtenTransposeIntOp(
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenTOp(
+    AtenTOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+  knowledge.dtype = input.dtype;
+  if (!input.hasSizes)
+    return getLatticeElement(op.getResult()).join(knowledge);
+  int64_t inputRank = input.sizes.size();
+  if (inputRank >= 0 && inputRank <= 2) {
+    knowledge.hasSizes = input.hasSizes;
+    knowledge.sizes = input.sizes;
+    if (inputRank == 2)
+      std::swap(knowledge.sizes[0], knowledge.sizes[1]);
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenPermuteOp(
     AtenPermuteOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto input = operands[0]->getValue();
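Note: the transfer function propagates dtype unconditionally and sizes only when they are known, swapping the two dimensions in the rank-2 case; unknown dims stay unknown (the -1 annotations in the e2e tests above). A pure-Python sketch of the shape rule (t_output_sizes is a hypothetical helper, not the C++ API; None models hasSizes == false):

from typing import List, Optional

def t_output_sizes(input_sizes: Optional[List[int]]) -> Optional[List[int]]:
    if input_sizes is None:
        return None  # no size information to propagate
    rank = len(input_sizes)
    if rank > 2:
        return None  # visitAtenTOp leaves sizes unset; rank > 2 is invalid for t()
    sizes = list(input_sizes)
    if rank == 2:
        sizes[0], sizes[1] = sizes[1], sizes[0]  # the std::swap above
    return sizes

assert t_output_sizes([3, 4]) == [4, 3]
assert t_output_sizes([-1, -1]) == [-1, -1]  # dynamic dims stay dynamic
assert t_output_sizes([3]) == [3] and t_output_sizes([]) == []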

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -594,6 +594,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
     emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
     emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
+    emit("aten::t : (Tensor) -> (Tensor)")
 
     # Dict ops.
     emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
