     RealType,
     TFloat,
     TFloatHighPrecision,
-    TFloatOrBFloat16,
     TInt,
     TReal,
     TRealOrUInt8,
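Every removal below follows from this import change: the `TFloatOrBFloat16` alias is dropped and all of its uses become `TFloat`. A minimal sketch of the assumed relationship between the two aliases, using stand-in types (the real definitions live in onnxscript's typing module and are not part of this diff):

```python
from typing import Union

# Stand-ins for onnxscript's tensor types; illustrative only.
class BFLOAT16: ...
class FLOAT16: ...
class FLOAT: ...
class DOUBLE: ...

# Assumed shape of the change: TFloat now covers BFLOAT16 as well,
# which makes a separate TFloatOrBFloat16 alias redundant.
TFloat = Union[BFLOAT16, FLOAT16, FLOAT, DOUBLE]
TFloatOrBFloat16 = TFloat  # the alias this PR deletes
```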
@@ -2031,12 +2030,6 @@ def aten_convolution(
         stride = (stride, stride)
     strides = list(stride)
 
-    if bias is None:
-        weight_dim_0 = op.Shape(weight, start=0, end=1)
-        bias_shape = op.Expand(weight_dim_0, op.Constant(value_ints=[1]))
-        zero = op.CastLike(0.0, input)
-        bias = op.Expand(zero, bias_shape)
-
     result = _aten_convolution_onnx(
         input,
         weight,
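The deleted branch synthesized an explicit zero bias whenever none was given. A numpy sketch of what it computed, assuming the point of the removal is that the downstream helper can now rely on ONNX `Conv` treating its bias input as optional:

```python
import numpy as np

def zero_bias_like(weight: np.ndarray, input_dtype) -> np.ndarray:
    """What the deleted Shape/Expand/CastLike sequence built: a zero
    bias with one entry per output channel (dim 0 of the weight)."""
    out_channels = weight.shape[0]  # op.Shape(weight, start=0, end=1)
    return np.zeros(out_channels, dtype=input_dtype)  # op.Expand(zero, ...)

# A conv weight of shape (out_channels=8, in_channels=3, kH=3, kW=3)
print(zero_bias_like(np.empty((8, 3, 3, 3)), np.float32))  # eight zeros
```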
@@ -3564,14 +3557,14 @@ def aten_flipud(self: TensorType) -> TensorType:
 
 
 @torch_op("aten::floor", traceable=True)
-def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_floor(self: TFloat) -> TFloat:
     """floor(Tensor self) -> Tensor"""
 
     return op.Floor(self)
 
 
 @torch_op("math::floor", traceable=True)
-def python_math_floor(self: TFloatOrBFloat16) -> TInt:
+def python_math_floor(self: TFloat) -> TInt:
     """floor(Tensor self) -> Tensor"""
     floor = op.Floor(self)
     return op.Cast(floor, to=INT64.dtype)
@@ -4533,7 +4526,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL:
 
 
 @torch_op("aten::isinf")
-def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isinf(self: TFloat) -> BOOL:
     """isinf(Tensor self) -> Tensor"""
 
     # Added Cast inside the function so it can support all real dtypes naturally
@@ -4542,14 +4535,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
 
 
 @torch_op("aten::isnan")
-def aten_isnan(self: TFloatOrBFloat16) -> BOOL:
+def aten_isnan(self: TFloat) -> BOOL:
     """isnan(Tensor self) -> Tensor"""
 
     return op.IsNaN(self)
 
 
 @torch_op("aten::isneginf")
-def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isneginf(self: TFloat) -> BOOL:
     """isneginf(Tensor self) -> Tensor"""
 
     # Added Cast inside the function so it can support all real dtypes naturally
@@ -4558,7 +4551,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
 
 
 @torch_op("aten::isposinf")
-def aten_isposinf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isposinf(self: TFloat) -> BOOL:
     """isposinf(Tensor self) -> Tensor"""
 
     # Added Cast inside the function so it can support all real dtypes naturally
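The `isinf`/`isneginf`/`isposinf` bodies all carry the same comment about a Cast added so every real dtype is accepted. A numpy sketch of that cast-then-test pattern (the actual ONNX ops sit outside these hunks, so this is an illustration, not the function bodies):

```python
import numpy as np

def isneginf_via_cast(x: np.ndarray) -> np.ndarray:
    # Cast to a float type first so integer inputs are handled uniformly,
    # then test for infinities of the wanted sign.
    as_float = x.astype(np.float32)  # stands in for op.Cast(self, to=FLOAT)
    return np.isinf(as_float) & (as_float < 0)

print(isneginf_via_cast(np.array([-np.inf, 0.0, np.inf])))  # [ True False False]
print(isneginf_via_cast(np.array([1, 2, 3])))  # integer input: all False
```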
@@ -4778,42 +4771,42 @@ def aten_linspace(
 
 
 @torch_op("aten::log", traceable=True)
-def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log(self: TFloat) -> TFloat:
     """log(Tensor self) -> Tensor"""
 
     return op.Log(self)
 
 
 @torch_op("aten::log10", traceable=True)
-def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log10(self: TFloat) -> TFloat:
     """log10(Tensor self) -> Tensor"""
 
     return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self))
 
 
 @torch_op("aten::log1p")
-def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log1p(self: TFloat) -> TFloat:
     """log1p(Tensor self) -> Tensor"""
 
     return op.Log(op.Add(self, 1.0))
 
 
 @torch_op("aten::log2", traceable=True)
-def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log2(self: TFloat) -> TFloat:
     """log2(Tensor self) -> Tensor"""
 
     return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))
 
 
 @torch_op("aten::logaddexp", traceable=True)
-def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat:
     """logaddexp(Tensor self, Tensor other) -> Tensor"""
 
     return op.Log(op.Add(op.Exp(self), op.Exp(other)))
 
 
 @torch_op("aten::logaddexp2", traceable=True)
-def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat:
     """logaddexp2(Tensor self, Tensor other) -> Tensor"""
     two = op.CastLike(2.0, self)
     summation = op.Add(op.Pow(two, self), op.Pow(two, other))
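For reference, the identities these bodies rely on, checked with numpy. Note that the direct `log(exp(a) + exp(b))` form in `aten_logaddexp` can overflow for large inputs, which `np.logaddexp` avoids; whether that matters here depends on the dtypes in play:

```python
import numpy as np

x = np.array([0.5, 1.0, 4.0], dtype=np.float64)

# Change of base, as in aten_log10 and aten_log2.
assert np.allclose(np.log(x) / np.log(10.0), np.log10(x))
assert np.allclose(np.log(x) / np.log(2.0), np.log2(x))

# Direct formula used in aten_logaddexp.
a, b = 1.0, 2.0
assert np.isclose(np.log(np.exp(a) + np.exp(b)), np.logaddexp(a, b))

# Base-2 variant built in aten_logaddexp2.
assert np.isclose(np.log2(2.0**a + 2.0**b), np.logaddexp2(a, b))
```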
@@ -4822,7 +4815,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
 
 
 @torch_op("aten::logcumsumexp", traceable=True)
-def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
+def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat:
     """logcumsumexp(Tensor self, int dim) -> Tensor"""
 
     if IsScalar(self):
@@ -4908,12 +4901,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
 
 
 @torch_op("aten::logit", private=True)
-def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def _aten_logit_onnx(self: TFloat) -> TFloat:
     return op.Log(op.Div(self, op.Sub(1.0, self)))
 
 
 @torch_op("aten::logit", private=True)
-def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
+def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat:
     eps = op.CastLike(eps, self)
     one = op.CastLike(1.0, self)
     temporary_self = op.Where(self <= one - eps, self, one - eps)
@@ -4923,7 +4916,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
 
 
 @torch_op("aten::logit", trace_only=True)
-def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16:
+def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat:
     """logit(Tensor self, float? eps=None) -> Tensor"""
     if eps is None:
         return _aten_logit_onnx(self)
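A numpy sketch of the dispatch `aten_logit` performs between the two private helpers shown above (clamping only when `eps` is given, as in `_aten_logit_clamp_onnx`):

```python
import numpy as np
from typing import Optional

def logit(x: np.ndarray, eps: Optional[float] = None) -> np.ndarray:
    """log(x / (1 - x)), clamping x into [eps, 1 - eps] only when eps
    is given -- the same two-path split as aten_logit."""
    if eps is not None:
        x = np.clip(x, eps, 1.0 - eps)  # _aten_logit_clamp_onnx path
    return np.log(x / (1.0 - x))        # _aten_logit_onnx path

print(logit(np.array([0.0, 0.5, 1.0]), eps=1e-6))  # finite at the endpoints
print(logit(np.array([0.25, 0.5, 0.75])))          # unclamped path
```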
@@ -6041,9 +6034,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
 
 
 @torch_op("aten::native_dropout", trace_only=True)
-def aten_native_dropout(
-    input: TFloatOrBFloat16, p: float, train: bool = True
-) -> Tuple[TFloatOrBFloat16, BOOL]:
+def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]:
     """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""
 
     result, mask = op.Dropout(input, p, train)
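`op.Dropout` returns both the dropped-out tensor and the keep mask, which is exactly the pair `native_dropout` must produce. A numpy model of that contract, using the inverted scaling that ONNX Dropout documents:

```python
import numpy as np

def dropout(x: np.ndarray, p: float, train: bool):
    """Return (result, keep_mask). In eval mode the input passes
    through untouched and the mask is all True."""
    if not train or p == 0.0:
        return x, np.ones(x.shape, dtype=bool)
    keep = np.random.rand(*x.shape) >= p  # keep with probability 1 - p
    return np.where(keep, x / (1.0 - p), 0.0), keep  # scale survivors by 1/(1-p)

result, mask = dropout(np.ones((2, 3), dtype=np.float32), p=0.5, train=True)
print(result)  # survivors are 2.0, dropped entries 0.0
print(mask)
```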
@@ -7055,7 +7046,7 @@ def aten_real(self: TensorType) -> TensorType:
 
 
 @torch_op("aten::reciprocal")
-def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_reciprocal(self: TFloat) -> TFloat:
     """reciprocal(Tensor self) -> Tensor"""
 
     return op.Reciprocal(self)
@@ -7074,7 +7065,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
 
 
 @torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
-def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
     """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
 
     # TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7355,7 +7346,7 @@ def aten_rrelu(
 
 
 @torch_op("aten::rsqrt", traceable=True)
-def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_rsqrt(self: TFloat) -> TFloat:
     """rsqrt(Tensor self) -> Tensor"""
 
     return op.Reciprocal(op.Sqrt(self))
@@ -7562,7 +7553,7 @@ def aten_sgn(self: TensorType) -> TensorType:
 
 
 @torch_op("aten::sigmoid", traceable=True)
-def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_sigmoid(self: TFloat) -> TFloat:
     """sigmoid(Tensor self) -> Tensor"""
 
     return op.Sigmoid(self)
@@ -7724,7 +7715,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
 
 
 @torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True)
-def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
+def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
     """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
 
     self_is_scalar = IsScalar(self)
@@ -7741,7 +7732,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
 
 
 @torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True)
-def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
+def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat:
     """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
 
     self_is_scalar = IsScalar(self)
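Both softmax variants begin by checking for a scalar input; the reshape/squeeze handling and the `op.Softmax` call sit outside these hunks. For reference, a numpy rendering of the reduction being delegated, written with the usual max-subtraction, which is mathematically equivalent to ONNX Softmax's definition:

```python
import numpy as np

def softmax(x: np.ndarray, dim: int) -> np.ndarray:
    # Subtract the max along dim for numerical stability, then normalize.
    shifted = x - x.max(axis=dim, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=dim, keepdims=True)

x = np.array([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
print(softmax(x, dim=-1))  # each row sums to 1
```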
@@ -7812,7 +7803,7 @@ def aten_split_with_sizes_copy(
 
 
 @torch_op("aten::sqrt", traceable=True)
-def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_sqrt(self: TFloat) -> TFloat:
     """sqrt(Tensor self) -> Tensor"""
 
     return op.Sqrt(self)
@@ -8402,7 +8393,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
 
 
 @torch_op("aten::trunc")
-def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_trunc(self: TFloat) -> TFloat:
     """trunc(Tensor self) -> Tensor"""
 
     # Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126
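The hunk ends just before the composed implementation; the linked ONNX issue discusses emulating trunc from simpler ops. One standard floor-based decomposition, shown here as a hedged numpy sketch (whether the function body uses exactly this form is not visible in the diff):

```python
import numpy as np

def trunc(x: np.ndarray) -> np.ndarray:
    # Round toward zero: floor the magnitude, then restore the sign.
    return np.sign(x) * np.floor(np.abs(x))

vals = np.array([-1.7, -0.2, 0.0, 0.2, 1.7])
assert np.array_equal(trunc(vals), np.trunc(vals))
print(trunc(vals))  # [-1. -0.  0.  0.  1.]
```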