@@ -3688,23 +3688,27 @@ def python_math_floor(self: TFloat) -> TInt:
3688
3688
3689
3689
3690
3690
@torch_op ("aten::floor_divide" , trace_only = True )
3691
- def aten_floor_divide (self : TFloat , other : TFloat ) -> TFloat :
3691
+ def aten_floor_divide (self : TTensor , other : TTensor ) -> TTensor :
3692
3692
"""floor_divide(Tensor self, Tensor other) -> Tensor"""
3693
3693
3694
- return op .Floor (op .Div (self , other ))
3694
+ if self .dtype .is_floating_point ():
3695
+ return op .Floor (op .Div (self , other ))
3695
3696
3697
+ assert self .dtype .is_integer ()
3696
3698
3697
- @torch_op ("aten::floor_divide" , trace_only = True )
3698
- def aten_floor_divide_int (self : TInt , other : TInt ) -> TInt :
3699
- """floor_divide(Tensor self, Tensor other) -> Tensor"""
3699
+ if not self .dtype .is_signed ():
3700
+ return op .Div (self , other )
3700
3701
3701
- # TODO(justinchuby): This can be simplified if we can constrain the
3702
- # inputs to be positive integers. Consider how we can embed constraints in the model.
3703
- dtype = self .dtype
3704
- self = op .Cast (self , to = FLOAT .dtype )
3705
- other = op .Cast (other , to = FLOAT .dtype )
3706
- result = op .Floor (op .Div (self , other ))
3707
- return op .Cast (result , to = dtype )
3702
+ # Convert truncation to flooring
3703
+ # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70
3704
+ # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
3705
+ # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
3706
+ offset = op .And (
3707
+ op .Not (op .Equal (op .Sign (self ), op .Sign (other ))),
3708
+ op .Cast (op .Mod (self , other ), to = BOOL .dtype ),
3709
+ )
3710
+ offset = op .Cast (offset , to = self .dtype )
3711
+ return op .Sub (op .Div (self , other ), offset )
3708
3712
3709
3713
3710
3714
@torch_op ("_operator::floordiv" , trace_only = True )
0 commit comments