Merged
144 changes: 109 additions & 35 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -1220,51 +1220,68 @@ def aten_binomial(
 @torch_op(
     (
         "aten::bitwise_and.Tensor",
-        "aten::bitwise_and.Scalar",
-        "aten::bitwise_and.Scalar_Tensor",
         "_operator::and_",
     ),
     trace_only=True,
 )
 def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""

-    assert self.dtype == other.dtype
+    assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+    dtype = self.dtype if self.dtype is not None else other.dtype
+    assert dtype is not None

-    if self.dtype.is_integer():
+    if dtype.is_integer():
         return op.BitwiseAnd(self, other)
-    if self.dtype == ir.DataType.BOOL:
+    if dtype == ir.DataType.BOOL:
         return op.And(self, other)
     raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


+@torch_op("aten::bitwise_and.Scalar", trace_only=True)
+def aten_bitwise_and_scalar(self: TTensor, other: int) -> TTensor:
+    """bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor"""
+
+    other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+    return aten_bitwise_and(self, other_tensor)
+
+
+@torch_op("aten::bitwise_and.Scalar_Tensor", trace_only=True)
+def aten_bitwise_and_scalar_tensor(self: float, other: TTensor) -> TTensor:
+    """bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+
+    self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+    return aten_bitwise_and(self_tensor, other)
+
+
 @torch_op(
     (
         "aten::bitwise_left_shift.Tensor",
-        "aten::bitwise_left_shift.Tensor_Scalar",
-        "aten::bitwise_left_shift.Scalar_Tensor",
         "_operator::__lshift__",
-        "aten::__lshift__.Scalar",
     ),
     trace_only=True,
 )
 def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
+    assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+    dtype = self.dtype if self.dtype is not None else other.dtype
+    assert dtype is not None
+
     # assert other >= 0
-    if self.dtype.bitwidth == 8:
+    if dtype.bitwidth == 8:
         unsigned_dtype = ir.DataType.UINT8
         signed_dtype = ir.DataType.INT8
-    elif self.dtype.bitwidth == 16:
+    elif dtype.bitwidth == 16:
         unsigned_dtype = ir.DataType.UINT16
         signed_dtype = ir.DataType.INT16
-    elif self.dtype.bitwidth == 32:
+    elif dtype.bitwidth == 32:
         unsigned_dtype = ir.DataType.UINT32
         signed_dtype = ir.DataType.INT32
-    elif self.dtype.bitwidth == 64:
+    elif dtype.bitwidth == 64:
         unsigned_dtype = ir.DataType.UINT64
         signed_dtype = ir.DataType.INT64
     else:
-        raise NotImplementedError(f"Not implemented for type {self.dtype}")
+        raise NotImplementedError(f"Not implemented for type {dtype}")

     self = op.Cast(self, to=unsigned_dtype)
     other = op.Cast(other, to=unsigned_dtype)
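
A note on the dtype handling added above: at trace time one operand's dtype
may still be unknown (None), so the known side wins, and two known but
different dtypes are rejected. In isolation (a minimal sketch, not code from
this PR; the helper name is made up):

    def _resolve_binary_dtype(a_dtype, b_dtype):
        # One side may be None (unknown at trace time); the known dtype wins.
        assert a_dtype == b_dtype or a_dtype is None or b_dtype is None
        dtype = a_dtype if a_dtype is not None else b_dtype
        assert dtype is not None  # at least one operand must carry a dtype
        return dtype
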
@@ -1274,6 +1291,22 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
     return op.Cast(result, to=signed_dtype)


+@torch_op(
+    ("aten::bitwise_left_shift.Tensor_Scalar", "aten::__lshift__.Scalar"), trace_only=True
+)
+def aten_bitwise_left_shift_tensor_scalar(self: TInt, other: int) -> TInt:
+    """bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor"""
+    other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+    return aten_bitwise_left_shift(self, other_tensor)
+
+
+@torch_op("aten::bitwise_left_shift.Scalar_Tensor", trace_only=True)
+def aten_bitwise_left_shift_scalar_tensor(self: int, other: TInt) -> TInt:
+    """bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+    self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+    return aten_bitwise_left_shift(self_tensor, other)
+
+
 @torch_op("aten::bitwise_not", trace_only=True)
 def aten_bitwise_not(self: TTensor) -> TTensor:
     """bitwise_not(Tensor self) -> Tensor"""
@@ -1288,54 +1321,69 @@ def aten_bitwise_not(self: TTensor) -> TTensor:
 @torch_op(
     (
         "aten::bitwise_or.Tensor",
-        "aten::bitwise_or.Scalar",
-        "aten::bitwise_or.Scalar_Tensor",
         "_operator::or_",
     ),
     trace_only=True,
 )
 def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""

-    assert self.dtype == other.dtype
+    assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+    dtype = self.dtype if self.dtype is not None else other.dtype
+    assert dtype is not None

-    if self.dtype.is_integer():
+    if dtype.is_integer():
         return op.BitwiseOr(self, other)
-    if self.dtype == ir.DataType.BOOL:
+    if dtype == ir.DataType.BOOL:
         return op.Or(self, other)
     raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


+@torch_op("aten::bitwise_or.Scalar", trace_only=True)
+def aten_bitwise_or_scalar(self: TTensor, other: int) -> TTensor:
+    """bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor"""
+    other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+    return aten_bitwise_or(self, other_tensor)
+
+
+@torch_op("aten::bitwise_or.Scalar_Tensor", trace_only=True)
+def aten_bitwise_or_scalar_tensor(self: int, other: TTensor) -> TTensor:
+    """bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+    self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+    return aten_bitwise_or(self_tensor, other)
+
+
 @torch_op(
     (
         "aten::bitwise_right_shift.Tensor",
-        "aten::bitwise_right_shift.Tensor_Scalar",
-        "aten::bitwise_right_shift.Scalar_Tensor",
         "_operator::__rshift__",
-        "aten::__rshift__.Scalar",
     ),
     trace_only=True,
 )
 def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt:
     """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
-    if self.dtype.bitwidth == 8:
+    assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+    dtype = self.dtype if self.dtype is not None else other.dtype
+    assert dtype is not None
+
+    if dtype.bitwidth == 8:
         unsigned_dtype = ir.DataType.UINT8
         signed_dtype = ir.DataType.INT8
         mask = ir.tensor(0xFF, dtype=unsigned_dtype)
-    elif self.dtype.bitwidth == 16:
+    elif dtype.bitwidth == 16:
         unsigned_dtype = ir.DataType.UINT16
         signed_dtype = ir.DataType.INT16
         mask = ir.tensor(0xFFFF, dtype=unsigned_dtype)
-    elif self.dtype.bitwidth == 32:
+    elif dtype.bitwidth == 32:
         unsigned_dtype = ir.DataType.UINT32
         signed_dtype = ir.DataType.INT32
         mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype)
-    elif self.dtype.bitwidth == 64:
+    elif dtype.bitwidth == 64:
         unsigned_dtype = ir.DataType.UINT64
         signed_dtype = ir.DataType.INT64
         mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype)
     else:
-        raise NotImplementedError(f"Not implemented for type {self.dtype}")
+        raise NotImplementedError(f"Not implemented for type {dtype}")

     negative = op.Less(self, 0)
     self = op.Cast(self, to=unsigned_dtype)
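
On the unsigned round-trip in both shift functions: ONNX BitShift is defined
only for unsigned integer tensors, and a shift on unsigned values is logical,
whereas PyTorch's >> on signed integers is arithmetic. Presumably that is
what the `negative` flag and the all-ones `mask` above are for: re-filling
the sign bits of negative inputs after the logical shift (the lines that do
so fall outside this hunk). The same idea at 8-bit width on plain Python ints
(a sketch, not code from this PR):

    def _arithmetic_rshift8(x: int, n: int) -> int:
        u = x & 0xFF  # reinterpret int8 as uint8
        shifted = u >> n  # logical shift
        if x < 0:
            shifted |= 0xFF ^ (0xFF >> n)  # re-inject the sign bits
        return shifted - 0x100 if shifted & 0x80 else shifted  # back to int8

    assert _arithmetic_rshift8(-4, 1) == -4 >> 1 == -2

Left shifts need no such fix-up: truncating the unsigned result already
matches two's-complement behavior, so aten_bitwise_left_shift only casts the
result back to the signed dtype.
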
@@ -1356,24 +1404,50 @@ def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt:


 @torch_op(
-    (
-        "aten::bitwise_xor.Tensor",
-        "aten::bitwise_xor.Scalar",
-        "aten::bitwise_xor.Scalar_Tensor",
-    ),
-    trace_only=True,
+    ("aten::bitwise_right_shift.Tensor_Scalar", "aten::__rshift__.Scalar"), trace_only=True
 )
+def aten_bitwise_right_shift_tensor_scalar(self: TInt, other: int) -> TInt:
+    """bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor"""
+    other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+    return aten_bitwise_right_shift(self, other_tensor)
+
+
+@torch_op("aten::bitwise_right_shift.Scalar_Tensor", trace_only=True)
+def aten_bitwise_right_shift_scalar_tensor(self: int, other: TInt) -> TInt:
+    """bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+    self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+    return aten_bitwise_right_shift(self_tensor, other)
+
+
+@torch_op("aten::bitwise_xor.Tensor", trace_only=True)
 def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
-    assert self.dtype == other.dtype

-    if self.dtype.is_integer():
+    assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+    dtype = self.dtype if self.dtype is not None else other.dtype
+    assert dtype is not None
+
+    if dtype.is_integer():
         return op.BitwiseXor(self, other)
-    if self.dtype == ir.DataType.BOOL:
+    if dtype == ir.DataType.BOOL:
         return op.Xor(self, other)
     raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


+@torch_op("aten::bitwise_xor.Scalar", trace_only=True)
+def aten_bitwise_xor_scalar(self: TTensor, other: int) -> TTensor:
+    """bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor"""
+    other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+    return aten_bitwise_xor(self, other_tensor)
+
+
+@torch_op("aten::bitwise_xor.Scalar_Tensor", trace_only=True)
+def aten_bitwise_xor_scalar_tensor(self: int, other: TTensor) -> TTensor:
+    """bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+    self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+    return aten_bitwise_xor(self_tensor, other)
+
+
 @torch_op("aten::blackman_window", trace_only=True)
 def aten_blackman_window(
     window_length: int,
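
Net effect on registration (a summary of the hunks above): each ATen overload
now resolves to a dedicated function instead of one function registered under
every overload name. For xor, for example:

    aten::bitwise_xor.Tensor         -> aten_bitwise_xor
    aten::bitwise_xor.Scalar         -> aten_bitwise_xor_scalar
    aten::bitwise_xor.Scalar_Tensor  -> aten_bitwise_xor_scalar_tensor

so a Python scalar operand is materialized with the tensor's dtype before the
shared Tensor implementation runs, presumably the failure mode the new test
below exercises.
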
13 changes: 13 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -225,6 +225,19 @@ def forward(self, q, k, v):
         )
         _testing.assert_onnx_program(onnx_program)

+    def test_bitwise_and_scalar(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return x & 3
+
+        onnx_program = torch.onnx.export(
+            Model(),
+            (torch.tensor([1, 2, 3, 4, 5]),),
+            dynamo=True,
+            verbose=False,
+        )
+        _testing.assert_onnx_program(onnx_program)
+

 if __name__ == "__main__":
     unittest.main()
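
The same export-and-compare pattern would cover the other new overloads, e.g.
a hypothetical companion test for the shift path (a sketch, not part of this
PR):

    def test_bitwise_right_shift_scalar(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return x >> 1

        onnx_program = torch.onnx.export(
            Model(),
            (torch.tensor([1, 2, 3, 4, 5]),),
            dynamo=True,
            verbose=False,
        )
        _testing.assert_onnx_program(onnx_program)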