diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e088b887f..2ec3b8f20 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -148,6 +148,11 @@ def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: return op.Add(self, other) +@torch_op(("_operator::add"), trace_only=True) +def operator_add(self: TTensor, other: TTensor) -> TTensor: + return op.Add(self, other) + + @torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True) def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" @@ -5567,7 +5572,7 @@ def aten_msort(self: TensorType) -> TensorType: @torch_op( - ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"), + ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), trace_only=True, ) def aten_mul(self: TTensor, other: TTensor) -> TTensor: @@ -5579,6 +5584,11 @@ def aten_mul(self: TTensor, other: TTensor) -> TTensor: return op.Mul(self, other) +@torch_op("_operator::mul", trace_only=True) +def operator_mul(self: TTensor, other: TTensor) -> TTensor: + return op.Mul(self, other) + + @torch_op( ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), trace_only=True,