Skip to content

Commit d68d652

Browse files
authored
[torchlib] Mark a few ops as traceable (#1889)
- pow - sqrt - rsqrt - round
1 parent 35fdcf5 commit d68d652

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6619,7 +6619,8 @@ def aten_positive(self: TensorType) -> TensorType:
66196619
"aten::pow.Tensor_Tensor",
66206620
"aten::pow.Tensor_Scalar",
66216621
"_operator::pow",
6622-
)
6622+
),
6623+
traceable=True,
66236624
)
66246625
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
66256626
"""pow(Tensor self, Tensor exponent) -> Tensor"""
@@ -7304,7 +7305,7 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te
73047305
raise NotImplementedError()
73057306

73067307

7307-
@torch_op("aten::round")
7308+
@torch_op("aten::round", traceable=True)
73087309
def aten_round(self: TFloat) -> TFloat:
73097310
"""round(Tensor self) -> Tensor"""
73107311

@@ -7353,7 +7354,7 @@ def aten_rrelu(
73537354
raise NotImplementedError()
73547355

73557356

7356-
@torch_op("aten::rsqrt")
7357+
@torch_op("aten::rsqrt", traceable=True)
73577358
def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
73587359
"""rsqrt(Tensor self) -> Tensor"""
73597360

@@ -7810,7 +7811,7 @@ def aten_split_with_sizes_copy(
78107811
raise NotImplementedError()
78117812

78127813

7813-
@torch_op("aten::sqrt")
7814+
@torch_op("aten::sqrt", traceable=True)
78147815
def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
78157816
"""sqrt(Tensor self) -> Tensor"""
78167817

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1349,7 +1349,6 @@ def _where_input_wrangler(
13491349
.xfail(
13501350
variant_name="decimals_0",
13511351
reason="This variant does not accept decimals",
1352-
test_class_name="TestOutputConsistencyEager",
13531352
)
13541353
.xfail(
13551354
variant_name="decimals_3",

0 commit comments

Comments
 (0)