3939 RealType ,
4040 TFloat ,
4141 TFloatHighPrecision ,
42- TFloatOrBFloat16 ,
4342 TInt ,
4443 TReal ,
4544 TRealOrUInt8 ,
@@ -3564,14 +3563,14 @@ def aten_flipud(self: TensorType) -> TensorType:
35643563
35653564
35663565@torch_op ("aten::floor" , traceable = True )
3567- def aten_floor (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
3566+ def aten_floor (self : TFloat ) -> TFloat :
35683567 """floor(Tensor self) -> Tensor"""
35693568
35703569 return op .Floor (self )
35713570
35723571
35733572@torch_op ("math::floor" , traceable = True )
3574- def python_math_floor (self : TFloatOrBFloat16 ) -> TInt :
3573+ def python_math_floor (self : TFloat ) -> TInt :
35753574 """floor(Tensor self) -> Tensor"""
35763575 floor = op .Floor (self )
35773576 return op .Cast (floor , to = INT64 .dtype )
@@ -4533,7 +4532,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL:
45334532
45344533
45354534@torch_op ("aten::isinf" )
4536- def aten_isinf (self : TFloatOrBFloat16 ) -> BOOL :
4535+ def aten_isinf (self : TFloat ) -> BOOL :
45374536 """isinf(Tensor self) -> Tensor"""
45384537
45394538 # Added Cast inside the function so it can support all real dtypes naturally
@@ -4542,14 +4541,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
45424541
45434542
45444543@torch_op ("aten::isnan" )
4545- def aten_isnan (self : TFloatOrBFloat16 ) -> BOOL :
4544+ def aten_isnan (self : TFloat ) -> BOOL :
45464545 """isnan(Tensor self) -> Tensor"""
45474546
45484547 return op .IsNaN (self )
45494548
45504549
45514550@torch_op ("aten::isneginf" )
4552- def aten_isneginf (self : TFloatOrBFloat16 ) -> BOOL :
4551+ def aten_isneginf (self : TFloat ) -> BOOL :
45534552 """isneginf(Tensor self) -> Tensor"""
45544553
45554554 # Added Cast inside the function so it can support all real dtypes naturally
@@ -4558,7 +4557,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
45584557
45594558
45604559@torch_op ("aten::isposinf" )
4561- def aten_isposinf (self : TFloatOrBFloat16 ) -> BOOL :
4560+ def aten_isposinf (self : TFloat ) -> BOOL :
45624561 """isposinf(Tensor self) -> Tensor"""
45634562
45644563 # Added Cast inside the function so it can support all real dtypes naturally
@@ -4778,42 +4777,42 @@ def aten_linspace(
47784777
47794778
47804779@torch_op ("aten::log" , traceable = True )
4781- def aten_log (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
4780+ def aten_log (self : TFloat ) -> TFloat :
47824781 """log(Tensor self) -> Tensor"""
47834782
47844783 return op .Log (self )
47854784
47864785
47874786@torch_op ("aten::log10" , traceable = True )
4788- def aten_log10 (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
4787+ def aten_log10 (self : TFloat ) -> TFloat :
47894788 """log10(Tensor self) -> Tensor"""
47904789
47914790 return op .Div (op .Log (self ), op .CastLike (op .Log (10.0 ), self ))
47924791
47934792
47944793@torch_op ("aten::log1p" )
4795- def aten_log1p (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
4794+ def aten_log1p (self : TFloat ) -> TFloat :
47964795 """log1p(Tensor self) -> Tensor"""
47974796
47984797 return op .Log (op .Add (self , 1.0 ))
47994798
48004799
48014800@torch_op ("aten::log2" , traceable = True )
4802- def aten_log2 (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
4801+ def aten_log2 (self : TFloat ) -> TFloat :
48034802 """log2(Tensor self) -> Tensor"""
48044803
48054804 return op .Div (op .Log (self ), op .CastLike (op .Log (2.0 ), self ))
48064805
48074806
48084807@torch_op ("aten::logaddexp" , traceable = True )
4809- def aten_logaddexp (self : TFloatOrBFloat16 , other : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
4808+ def aten_logaddexp (self : TFloat , other : TFloat ) -> TFloat :
48104809 """logaddexp(Tensor self, Tensor other) -> Tensor"""
48114810
48124811 return op .Log (op .Add (op .Exp (self ), op .Exp (other )))
48134812
48144813
48154814@torch_op ("aten::logaddexp2" , traceable = True )
4816- def aten_logaddexp2 (self : TFloatOrBFloat16 , other : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
4815+ def aten_logaddexp2 (self : TFloat , other : TFloat ) -> TFloat :
48174816 """logaddexp2(Tensor self, Tensor other) -> Tensor"""
48184817 two = op .CastLike (2.0 , self )
48194818 summation = op .Add (op .Pow (two , self ), op .Pow (two , other ))
@@ -4822,7 +4821,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOr
48224821
48234822
48244823@torch_op ("aten::logcumsumexp" , traceable = True )
4825- def aten_logcumsumexp (self : TFloatOrBFloat16 , dim : int ) -> TFloatOrBFloat16 :
4824+ def aten_logcumsumexp (self : TFloat , dim : int ) -> TFloat :
48264825 """logcumsumexp(Tensor self, int dim) -> Tensor"""
48274826
48284827 if IsScalar (self ):
@@ -4908,12 +4907,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
49084907
49094908
49104909@torch_op ("aten::logit" , private = True )
4911- def _aten_logit_onnx (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
4910+ def _aten_logit_onnx (self : TFloat ) -> TFloat :
49124911 return op .Log (op .Div (self , op .Sub (1.0 , self )))
49134912
49144913
49154914@torch_op ("aten::logit" , private = True )
4916- def _aten_logit_clamp_onnx (self : TFloatOrBFloat16 , eps : float ) -> TFloatOrBFloat16 :
4915+ def _aten_logit_clamp_onnx (self : TFloat , eps : float ) -> TFloat :
49174916 eps = op .CastLike (eps , self )
49184917 one = op .CastLike (1.0 , self )
49194918 temporary_self = op .Where (self <= one - eps , self , one - eps )
@@ -4923,7 +4922,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat
49234922
49244923
49254924@torch_op ("aten::logit" , trace_only = True )
4926- def aten_logit (self : TFloatOrBFloat16 , eps : Optional [float ] = None ) -> TFloatOrBFloat16 :
4925+ def aten_logit (self : TFloat , eps : Optional [float ] = None ) -> TFloat :
49274926 """logit(Tensor self, float? eps=None) -> Tensor"""
49284927 if eps is None :
49294928 return _aten_logit_onnx (self )
@@ -6041,9 +6040,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
60416040
60426041
60436042@torch_op ("aten::native_dropout" , trace_only = True )
6044- def aten_native_dropout (
6045- input : TFloatOrBFloat16 , p : float , train : bool = True
6046- ) -> Tuple [TFloatOrBFloat16 , BOOL ]:
6043+ def aten_native_dropout (input : TFloat , p : float , train : bool = True ) -> Tuple [TFloat , BOOL ]:
60476044 """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""
60486045
60496046 result , mask = op .Dropout (input , p , train )
@@ -7055,7 +7052,7 @@ def aten_real(self: TensorType) -> TensorType:
70557052
70567053
70577054@torch_op ("aten::reciprocal" )
7058- def aten_reciprocal (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
7055+ def aten_reciprocal (self : TFloat ) -> TFloat :
70597056 """reciprocal(Tensor self) -> Tensor"""
70607057
70617058 return op .Reciprocal (self )
@@ -7074,7 +7071,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
70747071
70757072
70767073@torch_op (("aten::remainder.Tensor" , "aten::remainder.Scalar" ))
7077- def aten_remainder (self : TFloatOrBFloat16 , other : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
7074+ def aten_remainder (self : TFloat , other : TFloat ) -> TFloat :
70787075 """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
70797076
70807077 # TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7355,7 +7352,7 @@ def aten_rrelu(
73557352
73567353
73577354@torch_op ("aten::rsqrt" , traceable = True )
7358- def aten_rsqrt (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
7355+ def aten_rsqrt (self : TFloat ) -> TFloat :
73597356 """rsqrt(Tensor self) -> Tensor"""
73607357
73617358 return op .Reciprocal (op .Sqrt (self ))
@@ -7562,7 +7559,7 @@ def aten_sgn(self: TensorType) -> TensorType:
75627559
75637560
75647561@torch_op ("aten::sigmoid" , traceable = True )
7565- def aten_sigmoid (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
7562+ def aten_sigmoid (self : TFloat ) -> TFloat :
75667563 """sigmoid(Tensor self) -> Tensor"""
75677564
75687565 return op .Sigmoid (self )
@@ -7724,7 +7721,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
77247721
77257722
77267723@torch_op (("aten::softmax.int" , "aten::special_softmax" ), trace_only = True )
7727- def aten_softmax (self : TFloatOrBFloat16 , dim : int , dtype : int = - 1 ) -> TFloatOrBFloat16 :
7724+ def aten_softmax (self : TFloat , dim : int , dtype : int = - 1 ) -> TFloat :
77287725 """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
77297726
77307727 self_is_scalar = IsScalar (self )
@@ -7741,7 +7738,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB
77417738
77427739
77437740@torch_op (("aten::softmax.int" , "aten::special_softmax" ), traceable = True )
7744- def aten_softmax_no_dtype (self : TFloatOrBFloat16 , dim : int ) -> TFloatOrBFloat16 :
7741+ def aten_softmax_no_dtype (self : TFloat , dim : int ) -> TFloat :
77457742 """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
77467743
77477744 self_is_scalar = IsScalar (self )
@@ -7812,7 +7809,7 @@ def aten_split_with_sizes_copy(
78127809
78137810
78147811@torch_op ("aten::sqrt" , traceable = True )
7815- def aten_sqrt (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
7812+ def aten_sqrt (self : TFloat ) -> TFloat :
78167813 """sqrt(Tensor self) -> Tensor"""
78177814
78187815 return op .Sqrt (self )
@@ -8402,7 +8399,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
84028399
84038400
84048401@torch_op ("aten::trunc" )
8405- def aten_trunc (self : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
8402+ def aten_trunc (self : TFloat ) -> TFloat :
84068403 """trunc(Tensor self) -> Tensor"""
84078404
84088405 # Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126
0 commit comments