
Commit f34eb66

Gaurav Shukla (Shukla-Gaurav) authored and committed
[TORCH][MLIR] Add E2E support for [aten.gt.Scalar|aten.where.self]
This commit adds lowering of `aten.gt.Scalar` and `aten.where.self` as part of the element-wise ops lowering.

Signed-off-by: Gaurav Shukla <[email protected]>
1 parent 2414bdb commit f34eb66
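
As a quick point of reference (not part of the patch), the eager-mode PyTorch behavior that the new lowerings and e2e tests below exercise can be sketched as follows:

import torch

# aten.gt.Scalar: compare every element against a Python scalar. The result is
# a boolean (1-bit integer) tensor of the same shape, which is why the
# RefineTypes change below assigns an i1 dtype to comparison results.
x = torch.rand(3, 5)
mask = torch.gt(x, 0.6)
assert mask.dtype == torch.bool and mask.shape == x.shape

# aten.where.self: pick elements of `b` where the condition is true and
# elements of `c` elsewhere; all three operands broadcast together, so the
# (3, 4, 5), (4, 5) and (5) shapes used by ElementwiseWhereSelfModule yield a
# (3, 4, 5) result.
a, b, c = torch.rand(3, 4, 5), torch.rand(4, 5), torch.rand(5)
out = torch.where(a > 0.5, b, c)
assert out.shape == (3, 4, 5)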

5 files changed: +124 −5 lines changed


e2e_testing/torchscript/elementwise.py

Lines changed: 43 additions & 0 deletions
@@ -85,6 +85,29 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseWhereSelfModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, a, b, c):
+        return torch.where(a > 0.5, b, c)
+
+
+@register_test_case(module_factory=lambda: ElementwiseWhereSelfModule())
+def ElementwiseWhereSelfModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5))
+
+
+# ==============================================================================
+
+
 # Addition is an interesting special case of a binary op, because under the hood
 # it carries a third scalar "alpha" parameter, which needs special handling.
 class ElementwiseAddModule(torch.nn.Module):

@@ -303,6 +326,26 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseGtScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.gt(x, 0.6)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGtScalarModule())
+def ElementwiseGtScalarModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
+
+# ==============================================================================
+
+
 class ElementwiseClampModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 16 additions & 0 deletions
@@ -1071,6 +1071,22 @@ def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
   let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
 }
 
+def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$condition,
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
+}
+
 def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
     AllowsTypeRefinement,
     HasValueSemantics
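
For context (an illustrative sketch, not part of the patch): aten::where.self is the overload TorchScript resolves when torch.where is called with three tensor arguments, which is what lets the JIT IR importer map such calls onto the op defined above:

import torch

@torch.jit.script
def select(cond: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # With three tensor arguments, this call resolves to the
    # aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor) schema.
    return torch.where(cond, a, b)

# Dumping the TorchScript graph shows the aten::where node that the
# torch-mlir importer turns into torch.aten.where.self.
print(select.graph)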

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 24 additions & 2 deletions
@@ -1684,6 +1684,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
     return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
   }
+
+  if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
+    Type dtype = gtScalar.self().getType().cast<ValueTensorType>().getDtype();
+    if (!dtype.isa<mlir::FloatType>()) {
+      gtScalar.emitError("unimplemented: non-floating point operand dtype");
+      return nullptr;
+    }
+    Value otherPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
+    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
+                                   payloadArgs[0], otherPromoted);
+  }
+
+  if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
+    Type dtype = converter->convertType(whereSelf.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
+    return b.create<SelectOp>(loc, payloadArgs[0], lhs, rhs);
+  }
+
   if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
     if (!lerp.getType()
              .cast<ValueTensorType>()

@@ -2040,7 +2061,7 @@ struct ConvertElementwiseOp : ConversionPattern {
           AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
           AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
           AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
-          AtenBitwiseAndTensorOp>(op))
+          AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp>(op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
   if (failed(verifyLinalgCompatibleTypes(op, rewriter)))

@@ -3461,7 +3482,8 @@ class ConvertTorchToLinalg
         AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
         AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
         AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
-        AtenReciprocalOp, AtenBitwiseAndTensorOp>();
+        AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
+        AtenWhereSelfOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp<AtenSqueezeOp>();
     patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
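
The two new cases in createLinalgPayloadCalculationForElementwiseOp compute one scalar result per element inside the linalg.generic payload. A plain-Python sketch of that per-element computation (hypothetical helper names, illustration only; the real code builds arith.cmpf and select ops on the promoted MLIR values):

def gt_scalar_payload(self_val: float, other_scalar: float) -> bool:
    # Mirrors the arith.cmpf comparison against the scalar promoted to the
    # input's floating-point dtype. (The lowering uses the unordered "ugt"
    # predicate, so NaN operands compare true there, unlike Python's >.)
    return self_val > other_scalar

def where_self_payload(cond: bool, lhs: float, rhs: float) -> float:
    # Mirrors the select: lhs where the condition holds, rhs otherwise, after
    # both branches are converted to the result element type.
    return lhs if cond else rhs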

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 40 additions & 3 deletions
@@ -235,9 +235,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
     if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
             AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp,
-            AtenGeluBackwardOp, AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp,
-            AtenNeScalarOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp, AtenCosOp,
-            AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
+            AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp,
+            AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
             AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
             AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
             AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,

@@ -247,6 +246,20 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
+    // These comparison ops return a tensor with 1-bit integer dtype.
+    if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenNeScalarOp>(
+            op)) {
+      auto operand = operands[0]->getValue();
+      auto knowledge =
+          ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      if (operand.hasSizes) {
+        knowledge.hasSizes = true;
+        knowledge.sizes = operand.sizes;
+      }
+      knowledge.dtype = IntegerType::get(op->getContext(), 1);
+      return getLatticeElement(op->getResult(0)).join(knowledge);
+    }
+
     // Resize to [1, 1] with integer dtype.
     if (isa<AtenAnyOp, AtenAllOp>(op)) {
       auto input = operands[0]->getValue();

@@ -307,6 +320,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
                    AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
                    AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
       return visitBinaryBroadcastingOp(op, operands);
+    } else if (auto whereSelf = llvm::dyn_cast<AtenWhereSelfOp>(op)) {
+      return visitAtenWhereSelfOp(whereSelf, operands);
     } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
       return visitAtenLerpTensorOp(lerpTensor, operands);
     } else if (auto flatten = dyn_cast<AtenFlattenUsingIntsOp>(op)) {

@@ -487,6 +502,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult visitBinaryBroadcastingOp(
       Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
+  visitAtenWhereSelfOp(AtenWhereSelfOp op,
+                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
   visitAtenLerpTensorOp(AtenLerpTensorOp op,
                         ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult visitAtenFlattenUsingIntsOp(

@@ -856,6 +874,25 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenWhereSelfOp(
+    AtenWhereSelfOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto condition = operands[0]->getValue();
+  auto lhs = operands[1]->getValue();
+  auto rhs = operands[2]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(getContext());
+  if (condition.hasSizes && lhs.hasSizes && rhs.hasSizes) {
+    knowledge.hasSizes = true;
+    knowledge.sizes.resize(
+        std::max(condition.sizes.size(),
+                 std::max(lhs.sizes.size(), rhs.sizes.size())),
+        kUnknownSize);
+  }
+
+  knowledge.dtype = getPromotedResultType(getContext(), {&lhs, &rhs});
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenLerpTensorOp(
     AtenLerpTensorOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   // This is a general broadcasting shape transfer function.
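
The shape/dtype transfer function for aten.where.self only commits to a rank (the maximum of the three operand ranks, with every dimension left unknown) and to the dtype promoted from the two value operands. A small Python sketch of the same rule (hypothetical helper name and kUnknownSize stand-in, illustration only):

from typing import List, Optional

kUnknownSize = -1  # stand-in for the kUnknownSize sentinel used in RefineTypes

def where_self_result_sizes(cond_sizes: Optional[List[int]],
                            lhs_sizes: Optional[List[int]],
                            rhs_sizes: Optional[List[int]]) -> Optional[List[int]]:
    # Sizes are only known if all three operands have known rank; the result
    # rank is the maximum rank, and each dimension is left unknown rather than
    # computing the exact broadcast shape.
    if cond_sizes is None or lhs_sizes is None or rhs_sizes is None:
        return None
    rank = max(len(cond_sizes), len(lhs_sizes), len(rhs_sizes))
    return [kUnknownSize] * rank

# Example: ranks 3, 2 and 1 give an all-unknown rank-3 result, matching the
# (3, 4, 5) output of ElementwiseWhereSelfModule; the dtype is promoted from
# the lhs/rhs operands (the i1 condition does not participate).
print(where_self_result_sizes([3, 4, 5], [4, 5], [5]))  # [-1, -1, -1]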

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -479,6 +479,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
     emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
     emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
     emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
     emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::gelu : (Tensor) -> (Tensor)")
