
Commit 88b03d8

Improve aten_floor_divide for int inputs (#2592)
Fix aten_floor_divide for negative int inputs and large int inputs. I also combined the int and float overloads for #2580.

Fixes #2589.

Signed-off-by: Justin Chu <[email protected]>
1 parent 09bbd27 commit 88b03d8
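For context (an illustrative check, not part of the commit): PyTorch's floor_divide rounds toward negative infinity, while ONNX's integer Div truncates toward zero, which is why the translation below needs an explicit truncation-to-flooring correction for signed integers.

import torch

# torch.floor_divide floors (rounds toward negative infinity); a plain
# truncating integer division would yield -3 here instead of -4.
print(torch.floor_divide(torch.tensor(-7), torch.tensor(2)))  # tensor(-4)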

File tree

3 files changed: +17 -23 lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 16 additions & 12 deletions
@@ -3688,23 +3688,27 @@ def python_math_floor(self: TFloat) -> TInt:
 
 
 @torch_op("aten::floor_divide", trace_only=True)
-def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
+def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor:
     """floor_divide(Tensor self, Tensor other) -> Tensor"""
 
-    return op.Floor(op.Div(self, other))
+    if self.dtype.is_floating_point():
+        return op.Floor(op.Div(self, other))
 
+    assert self.dtype.is_integer()
 
-@torch_op("aten::floor_divide", trace_only=True)
-def aten_floor_divide_int(self: TInt, other: TInt) -> TInt:
-    """floor_divide(Tensor self, Tensor other) -> Tensor"""
+    if not self.dtype.is_signed():
+        return op.Div(self, other)
 
-    # TODO(justinchuby): This can be simplified if we can constrain the
-    # inputs to be positive integers. Consider how we can embed constraints in the model.
-    dtype = self.dtype
-    self = op.Cast(self, to=FLOAT.dtype)
-    other = op.Cast(other, to=FLOAT.dtype)
-    result = op.Floor(op.Div(self, other))
-    return op.Cast(result, to=dtype)
+    # Convert truncation to flooring
+    # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70
+    # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
+    # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
+    offset = op.And(
+        op.Not(op.Equal(op.Sign(self), op.Sign(other))),
+        op.Cast(op.Mod(self, other), to=BOOL.dtype),
+    )
+    offset = op.Cast(offset, to=self.dtype)
+    return op.Sub(op.Div(self, other), offset)
 
 
 @torch_op("_operator::floordiv", trace_only=True)

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 1 addition & 10 deletions
@@ -2270,18 +2270,9 @@ def __init__(self):
     opinfo_core.BinaryUfuncInfo(
         "ops.aten.floor_divide",
         aten_name="floor_divide",
-        dtypes=common_dtype.floating_types_and_half(),
+        dtypes=common_dtype.all_types_and_half(),
         rhs_make_tensor_kwargs=dict(exclude_zero=True),
     ),
-    opinfo_core.BinaryUfuncInfo(
-        "ops.aten.floor_divide.int",
-        aten_name="floor_divide",
-        op=torch.ops.aten.floor_divide,
-        dtypes=common_dtype.integral_types(),
-        # Create only positive inputs
-        lhs_make_tensor_kwargs=dict(low=0),
-        rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0),
-    ),
     opinfo_core.OpInfo(
         "ops.aten.hamming_window",
         aten_name="hamming_window",

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 0 additions & 1 deletion
@@ -794,7 +794,6 @@ def _where_input_wrangler(
     TorchLibOpInfo("flatten", core_ops.aten_flatten),
     TorchLibOpInfo("floor", core_ops.aten_floor),
     TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide),
-    TorchLibOpInfo("ops.aten.floor_divide.int", core_ops.aten_floor_divide_int),
     TorchLibOpInfo("fmod", core_ops.aten_fmod),
     TorchLibOpInfo("frac", core_ops.aten_frac),
     TorchLibOpInfo("full", core_ops.aten_full),
