Skip to content

Commit af81cbc

Browse files
authored
Merge branch 'main' into justinchu/torch-26
2 parents c73ae34 + 1426e9f commit af81cbc

File tree

5 files changed

+39
-46
lines changed

5 files changed

+39
-46
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
RealType,
4040
TFloat,
4141
TFloatHighPrecision,
42-
TFloatOrBFloat16,
4342
TInt,
4443
TReal,
4544
TRealOrUInt8,
@@ -3564,14 +3563,14 @@ def aten_flipud(self: TensorType) -> TensorType:
35643563

35653564

35663565
@torch_op("aten::floor", traceable=True)
3567-
def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
3566+
def aten_floor(self: TFloat) -> TFloat:
35683567
"""floor(Tensor self) -> Tensor"""
35693568

35703569
return op.Floor(self)
35713570

35723571

35733572
@torch_op("math::floor", traceable=True)
3574-
def python_math_floor(self: TFloatOrBFloat16) -> TInt:
3573+
def python_math_floor(self: TFloat) -> TInt:
35753574
"""floor(Tensor self) -> Tensor"""
35763575
floor = op.Floor(self)
35773576
return op.Cast(floor, to=INT64.dtype)
@@ -4533,7 +4532,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL:
45334532

45344533

45354534
@torch_op("aten::isinf")
4536-
def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
4535+
def aten_isinf(self: TFloat) -> BOOL:
45374536
"""isinf(Tensor self) -> Tensor"""
45384537

45394538
# Added Cast inside the function so it can support all real dtypes naturally
@@ -4542,14 +4541,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
45424541

45434542

45444543
@torch_op("aten::isnan")
4545-
def aten_isnan(self: TFloatOrBFloat16) -> BOOL:
4544+
def aten_isnan(self: TFloat) -> BOOL:
45464545
"""isnan(Tensor self) -> Tensor"""
45474546

45484547
return op.IsNaN(self)
45494548

45504549

45514550
@torch_op("aten::isneginf")
4552-
def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
4551+
def aten_isneginf(self: TFloat) -> BOOL:
45534552
"""isneginf(Tensor self) -> Tensor"""
45544553

45554554
# Added Cast inside the function so it can support all real dtypes naturally
@@ -4558,7 +4557,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
45584557

45594558

45604559
@torch_op("aten::isposinf")
4561-
def aten_isposinf(self: TFloatOrBFloat16) -> BOOL:
4560+
def aten_isposinf(self: TFloat) -> BOOL:
45624561
"""isposinf(Tensor self) -> Tensor"""
45634562

45644563
# Added Cast inside the function so it can support all real dtypes naturally
@@ -4778,42 +4777,42 @@ def aten_linspace(
47784777

47794778

47804779
@torch_op("aten::log", traceable=True)
4781-
def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
4780+
def aten_log(self: TFloat) -> TFloat:
47824781
"""log(Tensor self) -> Tensor"""
47834782

47844783
return op.Log(self)
47854784

47864785

47874786
@torch_op("aten::log10", traceable=True)
4788-
def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
4787+
def aten_log10(self: TFloat) -> TFloat:
47894788
"""log10(Tensor self) -> Tensor"""
47904789

47914790
return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self))
47924791

47934792

47944793
@torch_op("aten::log1p")
4795-
def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
4794+
def aten_log1p(self: TFloat) -> TFloat:
47964795
"""log1p(Tensor self) -> Tensor"""
47974796

47984797
return op.Log(op.Add(self, 1.0))
47994798

48004799

48014800
@torch_op("aten::log2", traceable=True)
4802-
def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
4801+
def aten_log2(self: TFloat) -> TFloat:
48034802
"""log2(Tensor self) -> Tensor"""
48044803

48054804
return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))
48064805

48074806

48084807
@torch_op("aten::logaddexp", traceable=True)
4809-
def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
4808+
def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat:
48104809
"""logaddexp(Tensor self, Tensor other) -> Tensor"""
48114810

48124811
return op.Log(op.Add(op.Exp(self), op.Exp(other)))
48134812

48144813

48154814
@torch_op("aten::logaddexp2", traceable=True)
4816-
def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
4815+
def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat:
48174816
"""logaddexp2(Tensor self, Tensor other) -> Tensor"""
48184817
two = op.CastLike(2.0, self)
48194818
summation = op.Add(op.Pow(two, self), op.Pow(two, other))
@@ -4822,7 +4821,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOr
48224821

48234822

48244823
@torch_op("aten::logcumsumexp", traceable=True)
4825-
def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
4824+
def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat:
48264825
"""logcumsumexp(Tensor self, int dim) -> Tensor"""
48274826

48284827
if IsScalar(self):
@@ -4908,12 +4907,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
49084907

49094908

49104909
@torch_op("aten::logit", private=True)
4911-
def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
4910+
def _aten_logit_onnx(self: TFloat) -> TFloat:
49124911
return op.Log(op.Div(self, op.Sub(1.0, self)))
49134912

49144913

49154914
@torch_op("aten::logit", private=True)
4916-
def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
4915+
def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat:
49174916
eps = op.CastLike(eps, self)
49184917
one = op.CastLike(1.0, self)
49194918
temporary_self = op.Where(self <= one - eps, self, one - eps)
@@ -4923,7 +4922,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat
49234922

49244923

49254924
@torch_op("aten::logit", trace_only=True)
4926-
def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16:
4925+
def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat:
49274926
"""logit(Tensor self, float? eps=None) -> Tensor"""
49284927
if eps is None:
49294928
return _aten_logit_onnx(self)
@@ -6041,9 +6040,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
60416040

60426041

60436042
@torch_op("aten::native_dropout", trace_only=True)
6044-
def aten_native_dropout(
6045-
input: TFloatOrBFloat16, p: float, train: bool = True
6046-
) -> Tuple[TFloatOrBFloat16, BOOL]:
6043+
def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]:
60476044
"""native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""
60486045

60496046
result, mask = op.Dropout(input, p, train)
@@ -7055,7 +7052,7 @@ def aten_real(self: TensorType) -> TensorType:
70557052

70567053

70577054
@torch_op("aten::reciprocal")
7058-
def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
7055+
def aten_reciprocal(self: TFloat) -> TFloat:
70597056
"""reciprocal(Tensor self) -> Tensor"""
70607057

70617058
return op.Reciprocal(self)
@@ -7074,7 +7071,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
70747071

70757072

70767073
@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
7077-
def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
7074+
def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
70787075
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
70797076

70807077
# TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7355,7 +7352,7 @@ def aten_rrelu(
73557352

73567353

73577354
@torch_op("aten::rsqrt", traceable=True)
7358-
def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
7355+
def aten_rsqrt(self: TFloat) -> TFloat:
73597356
"""rsqrt(Tensor self) -> Tensor"""
73607357

73617358
return op.Reciprocal(op.Sqrt(self))
@@ -7562,7 +7559,7 @@ def aten_sgn(self: TensorType) -> TensorType:
75627559

75637560

75647561
@torch_op("aten::sigmoid", traceable=True)
7565-
def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
7562+
def aten_sigmoid(self: TFloat) -> TFloat:
75667563
"""sigmoid(Tensor self) -> Tensor"""
75677564

75687565
return op.Sigmoid(self)
@@ -7724,7 +7721,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
77247721

77257722

77267723
@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True)
7727-
def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
7724+
def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
77287725
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
77297726

77307727
self_is_scalar = IsScalar(self)
@@ -7741,7 +7738,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB
77417738

77427739

77437740
@torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True)
7744-
def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
7741+
def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat:
77457742
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
77467743

77477744
self_is_scalar = IsScalar(self)
@@ -7812,7 +7809,7 @@ def aten_split_with_sizes_copy(
78127809

78137810

78147811
@torch_op("aten::sqrt", traceable=True)
7815-
def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
7812+
def aten_sqrt(self: TFloat) -> TFloat:
78167813
"""sqrt(Tensor self) -> Tensor"""
78177814

78187815
return op.Sqrt(self)
@@ -8402,7 +8399,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
84028399

84038400

84048401
@torch_op("aten::trunc")
8405-
def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
8402+
def aten_trunc(self: TFloat) -> TFloat:
84068403
"""trunc(Tensor self) -> Tensor"""
84078404

84088405
# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from onnxscript.function_libs.torch_lib.tensor_typing import (
2626
IntType,
2727
TFloat,
28-
TFloatOrBFloat16,
2928
TFloatOrUInt8,
3029
TInt,
3130
TReal,
@@ -364,13 +363,13 @@ def aten_conv_depthwise3d(
364363

365364
@torch_op("aten::cross_entropy_loss", traceable=True)
366365
def aten_cross_entropy_loss(
367-
self: TFloatOrBFloat16,
366+
self: TFloat,
368367
target: IntType,
369-
weight: Optional[TFloatOrBFloat16] = None,
368+
weight: Optional[TFloat] = None,
370369
reduction: int = 1, # default is 'mean'
371370
ignore_index: int = -100,
372371
label_smoothing: float = 0.0, # this was ignored due to ONNX not support
373-
) -> TFloatOrBFloat16:
372+
) -> TFloat:
374373
"""cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""
375374

376375
if reduction == 0: # "none"
@@ -812,7 +811,7 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te
812811

813812

814813
@torch_op("aten::leaky_relu")
815-
def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16:
814+
def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
816815
"""leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""
817816

818817
return op.LeakyRelu(self, alpha=negative_slope)
@@ -850,7 +849,7 @@ def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:
850849

851850

852851
@torch_op("aten::log_sigmoid")
853-
def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
852+
def aten_log_sigmoid(self: TFloat) -> TFloat:
854853
"""log_sigmoid(Tensor self) -> Tensor"""
855854

856855
return op.Log(op.Sigmoid(self))

onnxscript/function_libs/torch_lib/ops/special.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from onnxscript.function_libs.torch_lib.ops import common as common_ops
1919
from onnxscript.function_libs.torch_lib.registration import torch_op
20-
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16
20+
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
2121
from onnxscript.onnx_opset import opset18 as op
2222
from onnxscript.onnx_types import TensorType
2323

@@ -92,21 +92,21 @@ def aten_special_entr(self: TensorType) -> TensorType:
9292

9393

9494
@torch_op(("aten::erf", "aten::special_erf"))
95-
def aten_special_erf(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
95+
def aten_special_erf(self: TFloat) -> TFloat:
9696
"""erf(Tensor self) -> Tensor"""
9797

9898
return op.Erf(self)
9999

100100

101101
@torch_op(("aten::erfc", "aten::special_erfc"))
102-
def aten_special_erfc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
102+
def aten_special_erfc(self: TFloat) -> TFloat:
103103
"""erfc(Tensor self) -> Tensor"""
104104

105105
return op.Sub(1, op.Erf(self))
106106

107107

108108
@torch_op("aten::special_erfcx")
109-
def aten_special_erfcx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
109+
def aten_special_erfcx(self: TFloat) -> TFloat:
110110
"""special_erfcx(Tensor self) -> Tensor"""
111111

112112
return op.Mul(op.Exp(op.Pow(self, 2)), op.Sub(1, op.Erf(self)))
@@ -131,7 +131,7 @@ def aten_special_expit(self: TensorType) -> TensorType:
131131

132132

133133
@torch_op(("aten::expm1", "aten::special_expm1"))
134-
def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
134+
def aten_special_expm1(self: TFloat) -> TFloat:
135135
"""special_expm1(Tensor self) -> Tensor"""
136136

137137
return op.Sub(op.Exp(self), 1)
@@ -216,9 +216,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
216216

217217

218218
@torch_op(("aten::log_softmax.int", "aten::special_log_softmax"), trace_only=True)
219-
def aten_special_log_softmax(
220-
self: TFloatOrBFloat16, dim: int, dtype: int = -1
221-
) -> TFloatOrBFloat16:
219+
def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
222220
"""special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"""
223221

224222
self_is_scalar = IsScalar(self)
@@ -366,7 +364,7 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType:
366364

367365

368366
@torch_op(("aten::xlogy.Tensor", "aten::xlogy.Scalar_Self", "aten::xlogy.Scalar_Other"))
369-
def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
367+
def aten_special_xlogy(self: TFloat, other: TFloat) -> TFloat:
370368
"""special_xlogy(Tensor self, Tensor other) -> Tensor"""
371369

372370
# https://pytorch.org/docs/stable/special.html#torch.special.xlogy

onnxscript/function_libs/torch_lib/tensor_typing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
INT64,
4343
UINT8,
4444
]
45-
_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
45+
_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
4646
IntType = Union[INT8, INT16, INT32, INT64]
4747
RealType = Union[
4848
BFLOAT16,
@@ -61,7 +61,6 @@
6161
TTensor2 = TypeVar("TTensor2", bound=_TensorType)
6262
TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
6363
TFloat = TypeVar("TFloat", bound=_FloatType)
64-
TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
6564
TFloatOrUInt8 = TypeVar("TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8])
6665
TInt = TypeVar("TInt", bound=IntType)
6766
TReal = TypeVar("TReal", bound=RealType)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
onnx-weekly==1.17.0.dev20240715
1+
onnx-weekly==1.18.0.dev20240930

0 commit comments

Comments
 (0)