Skip to content

Commit 1544ee1

Browse files
authored
[torchlib] Do not register rsub (#1907)
Remove rsub since it is handled by decomp, and torch doesn't have a type promotion rule for rsub so we use sub instead. Tested with ```python import torch class Model(torch.nn.Module): def forward(self, x): return 1 - x ep = torch.export.export(Model(), (torch.tensor(1),)) print(ep) program = torch.onnx.export(Model(), (torch.tensor(1),), dynamo=True) print(program) ```
1 parent a4f3bcb commit 1544ee1

File tree

2 files changed

+2
-11
lines changed

2 files changed

+2
-11
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7352,18 +7352,11 @@ def aten_rsqrt(self: TFloat) -> TFloat:
73527352
return op.Reciprocal(op.Sqrt(self))
73537353

73547354

7355-
@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar"))
7355+
# Do not register rsub. It will be decomposed and type promoted by torch
73567356
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
73577357
"""rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
73587358

7359-
return op.Sub(other, op.Mul(self, alpha))
7360-
7361-
7362-
@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar"), trace_only=True, complex=True)
7363-
def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
7364-
"""rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
7365-
7366-
return aten_rsub(self, other, alpha)
7359+
raise NotImplementedError
73677360

73687361

73697362
@torch_op("aten::scalar_tensor", trace_only=True)

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,8 +1360,6 @@ def _where_input_wrangler(
13601360
),
13611361
TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals),
13621362
TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt),
1363-
TorchLibOpInfo("rsub", core_ops.aten_rsub),
1364-
TorchLibOpInfo("rsub", core_ops.aten_rsub_complex, complex=True),
13651363
TorchLibOpInfo(
13661364
"scalar_tensor",
13671365
core_ops.aten_scalar_tensor,

0 commit comments

Comments
 (0)