@@ -162,9 +162,15 @@ def aten_acosh(self: TFloat) -> TFloat:


 @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
-def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-    # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
+
+    if self.dtype == ir.DataType.BOOL:
+        # alpha can also be bool
+        if alpha == 0:
+            return op.Identity(self)
+        return op.Or(self, other)
+
     if alpha != 1.0:
         alpha = op.CastLike(alpha, other)
         other = op.Mul(other, alpha)
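
For reference, the new BOOL branch mirrors PyTorch's eager semantics: adding two bool tensors acts as an elementwise OR, and a False/zero alpha drops `other` entirely. A minimal sketch of that behavior (assuming a standard torch install; not part of the change itself):

    import torch

    a = torch.tensor([True, False, False])
    b = torch.tensor([False, True, False])

    # Bool addition behaves like elementwise OR, which is what the op.Or
    # branch emits in the exported graph.
    print(torch.add(a, b))               # tensor([ True,  True, False])

    # With alpha=False the second operand is effectively ignored,
    # matching the op.Identity(self) early return.
    print(torch.add(a, b, alpha=False))  # tensor([ True, False, False])
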
@@ -1237,11 +1243,16 @@ def aten_binomial(
     ),
     trace_only=True,
 )
-def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_and implements the BOOL variant

-    return op.BitwiseAnd(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype.is_integer():
+        return op.BitwiseAnd(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.And(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


 @torch_op(
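
A short illustration of why the dtype dispatch above (and in the bitwise_not/or/xor hunks below) is needed: PyTorch's bitwise_* ops accept both integer and bool tensors, while ONNX splits the two cases between BitwiseAnd (integer-only) and And (bool-only). A sketch assuming a standard torch install:

    import torch

    ints = torch.tensor([0b1100, 0b1010])
    # Integer inputs keep true bitwise semantics -> ONNX BitwiseAnd.
    print(torch.bitwise_and(ints, torch.tensor([0b1010, 0b0110])))  # tensor([8, 2])

    bools = torch.tensor([True, False])
    # Bool inputs degenerate to logical and -> ONNX And.
    print(torch.bitwise_and(bools, torch.tensor([True, True])))     # tensor([ True, False])
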
@@ -1329,11 +1340,14 @@ def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:


 @torch_op("aten::bitwise_not", trace_only=True)
-def aten_bitwise_not(self: TInt) -> TInt:
+def aten_bitwise_not(self: TTensor) -> TTensor:
     """bitwise_not(Tensor self) -> Tensor"""
-    # logical_not implements the BOOL variant

-    return op.BitwiseNot(self)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Not(self)
+    if self.dtype.is_integer():
+        return op.BitwiseNot(self)
+    raise NotImplementedError(f"Not implemented for type {self.dtype}")


 @torch_op(
@@ -1345,11 +1359,16 @@ def aten_bitwise_not(self: TInt) -> TInt:
     ),
     trace_only=True,
 )
-def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_or implements the BOOL variant

-    return op.BitwiseOr(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype.is_integer():
+        return op.BitwiseOr(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Or(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


 @torch_op(
@@ -1487,11 +1506,15 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
     ),
     trace_only=True,
 )
-def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_xor implements the BOOL variant
+    assert self.dtype == other.dtype

-    return op.BitwiseXor(self, other)
+    if self.dtype.is_integer():
+        return op.BitwiseXor(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Xor(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


 @torch_op("aten::blackman_window", trace_only=True)
@@ -5010,58 +5033,46 @@ def aten_logdet(self: TFloat) -> TFloat:
     return op.Log(op.Det(self))


-@torch_op(
-    (
-        "aten::logical_and",
-        "aten::bitwise_and.Tensor",
-        "aten::bitwise_and.Scalar",
-        "aten::bitwise_and.Scalar_Tensor",
-    ),
-    trace_only=True,
-)
-def aten_logical_and(self: BOOL, other: BOOL) -> BOOL:
+@torch_op("aten::logical_and", trace_only=True)
+def aten_logical_and(self: TTensor, other: TTensor) -> BOOL:
     """logical_and(Tensor self, Tensor other) -> Tensor"""

-    return op.And(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype == ir.DataType.BOOL:
+        return op.And(self, other)
+    return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))


-@torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True)
-def aten_logical_not(self: BOOL) -> BOOL:
+@torch_op("aten::logical_not", trace_only=True)
+def aten_logical_not(self: TTensor) -> BOOL:
     """logical_not(Tensor self) -> Tensor"""

-    return op.Not(self)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Not(self)
+    return op.Not(op.Cast(self, to=BOOL.dtype))


-@torch_op(
-    (
-        "aten::logical_or",
-        "aten::bitwise_or.Tensor",
-        "aten::bitwise_or.Scalar",
-        "aten::bitwise_or.Scalar_Tensor",
-        "aten::add.Tensor",
-        "aten::add.Scalar",
-    ),
-    trace_only=True,
-)
-def aten_logical_or(self: BOOL, other: BOOL) -> BOOL:
+@torch_op("aten::logical_or", trace_only=True)
+def aten_logical_or(self: TTensor, other: TTensor) -> BOOL:
     """logical_or(Tensor self, Tensor other) -> Tensor"""

-    return op.Or(self, other)
+    assert self.dtype == other.dtype

+    if self.dtype == ir.DataType.BOOL:
+        return op.Or(self, other)
+    return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))

-@torch_op(
-    (
-        "aten::logical_xor",
-        "aten::bitwise_xor.Tensor",
-        "aten::bitwise_xor.Scalar",
-        "aten::bitwise_xor.Scalar_Tensor",
-    ),
-    trace_only=True,
-)
-def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
+
+@torch_op("aten::logical_xor", trace_only=True)
+def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL:
     """logical_xor(Tensor self, Tensor other) -> Tensor"""

-    return op.Xor(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype == ir.DataType.BOOL:
+        return op.Xor(self, other)
+    return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))


 @torch_op("aten::logit", private=True)
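
The logical_* rewrites follow eager PyTorch, where non-bool inputs are read as "nonzero means True"; the added Cast-to-BOOL path reproduces that before applying And/Or/Xor/Not. A small sketch of the expected semantics (assuming a standard torch install):

    import torch

    x = torch.tensor([0, 1, 2])
    y = torch.tensor([2, 0, 2])

    # Non-bool inputs are treated as "nonzero is True", matching the
    # op.Cast(..., to=BOOL.dtype) path in the functions above.
    print(torch.logical_and(x, y))  # tensor([False, False,  True])
    print(torch.logical_not(x))     # tensor([ True, False, False])
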