4242 TInt ,
4343 TReal ,
4444 TRealOrUInt8 ,
45+ TRealUnlessFloat16OrInt8 ,
4546 TRealUnlessInt16OrInt8 ,
4647 TTensor ,
4748 TTensor2 ,
@@ -540,7 +541,7 @@ def _integral_to_be_adjusted(dtype: int) -> bool:
540541
541542@torch_op ("aten::arange" , trace_only = True )
542543def aten_arange (
543- end : float ,
544+ end : TRealUnlessFloat16OrInt8 ,
544545 dtype : int = - 1 ,
545546 layout : str = "" ,
546547 device : str = "" ,
@@ -549,10 +550,9 @@ def aten_arange(
549550 """arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
550551
551552 if dtype == - 1 or dtype is None :
552- if isinstance (end , int ):
553- result = op .Range (0 , end , 1 )
554- else :
555- result = op .Range (0.0 , end , 1.0 )
553+ zero = op .CastLike (0.0 , end )
554+ one = op .CastLike (1.0 , end )
555+ result = op .Range (zero , end , one )
556556 elif _range_supported (dtype ):
557557 end = op .Cast (end , to = dtype )
558558 zero = op .Cast (0 , to = dtype )
@@ -563,7 +563,7 @@ def aten_arange(
563563 # because the input dtype may be e.g. bfloat16 / int8 etc.
564564 # which Range does not support. The output type is ensured because the output
565565 # is casted to the specified dtype.
566- end = op .Constant ( value_float = float ( end ) )
566+ end = op .Cast ( end , to = FLOAT . dtype )
567567 zero = op .Constant (value_float = 0.0 )
568568 one = op .Constant (value_float = 1.0 )
569569 result = op .Cast (op .Range (zero , end , one ), to = dtype )
@@ -573,8 +573,8 @@ def aten_arange(
573573
574574@torch_op ("aten::arange.start" , trace_only = True )
575575def aten_arange_start (
576- start : TReal ,
577- end : TReal ,
576+ start : TRealUnlessFloat16OrInt8 ,
577+ end : TRealUnlessFloat16OrInt8 ,
578578 dtype : int = - 1 ,
579579 layout : str = "" ,
580580 device : str = "" ,
@@ -604,57 +604,56 @@ def aten_arange_start(
604604
605605
def _adjust_args_for_arange_int_dtype(
    start: TRealUnlessFloat16OrInt8,
    end: TRealUnlessFloat16OrInt8,
    step: TRealUnlessFloat16OrInt8,
) -> Tuple[FLOAT, FLOAT, FLOAT]:
    """Adjust (start, end, step) the way PyTorch does for small integral dtypes.

    PyTorch's arange treats these integral types differently from INT64
    (see torch/_refs/__init__.py), so the inputs are promoted to float32 and
    ``start`` is rounded: toward zero when negative, and additionally floored
    when the step is negative.

    Returns:
        The adjusted ``(start, end, step)`` triple, all cast to FLOAT.
    """
    # Promote everything to float32 first so Range receives one element type.
    float_zero = op.Cast(0.0, to=FLOAT.dtype)
    start = op.Cast(start, to=FLOAT.dtype)
    end = op.Cast(end, to=FLOAT.dtype)
    step = op.Cast(step, to=FLOAT.dtype)

    # Negative start rounds toward zero; a negative step then floors it.
    start = op.Where(op.Less(start, float_zero), op.Ceil(start), start)
    start = op.Where(op.Less(step, float_zero), op.Floor(start), start)

    return (start, end, step)
617620
618621
@torch_op("aten::arange.start_step", trace_only=True)
def aten_arange_start_step(
    start: TRealUnlessFloat16OrInt8,
    end: TRealUnlessFloat16OrInt8,
    step: TRealUnlessFloat16OrInt8 = 1.0,
    dtype: int = -1,
    layout: str = "",
    device: str = "",
    pin_memory: bool = False,
) -> TensorType:
    """arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

    if dtype == -1 or dtype is None:
        # FIX: also accept dtype=None here, matching aten_arange — otherwise
        # a None dtype falls through to the final branch and is handed to
        # op.Cast(..., to=None).
        # No target dtype: run Range in float32. Cast all three inputs
        # (FIX: including step) so ONNX Range sees a single element type.
        start = op.Cast(start, to=FLOAT.dtype)
        end = op.Cast(end, to=FLOAT.dtype)
        step = op.Cast(step, to=FLOAT.dtype)
        result = op.Range(start, end, step)
    elif _integral_to_be_adjusted(dtype):
        # PyTorch arange op handles these integral types differently from INT64,
        # so we have to adjust these arguments accordingly.
        # https://github.com/pytorch/pytorch/blob/121cfb60c0817816fcbe2190303b7f6d05c77cf3/torch/_refs/__init__.py#L4794
        start, end, step = _adjust_args_for_arange_int_dtype(start, end, step)
        result = op.Cast(op.Range(start, end, step), to=dtype)
    elif dtype == INT64.dtype:
        # INT64 is supported by Range directly; cast inputs to it.
        end = op.Cast(end, to=dtype)
        start = op.Cast(start, to=dtype)
        step = op.Cast(step, to=dtype)
        result = op.Range(start, end, step)
    else:
        # Cast input to float if dtype is not supported by Range,
        # because the input dtype may be e.g. bfloat16,
        # which Range does not support. The output type is ensured because the output
        # is casted to the specified dtype.
        end = op.Cast(end, to=FLOAT.dtype)
        start = op.Cast(start, to=FLOAT.dtype)
        step = op.Cast(step, to=FLOAT.dtype)
        result = op.Cast(op.Range(start, end, step), to=dtype)

    return result
@@ -4731,8 +4730,8 @@ def aten_linear_backward(
47314730
47324731@torch_op ("aten::linspace" , trace_only = True )
47334732def aten_linspace (
4734- start : float ,
4735- end : float ,
4733+ start : TFloat ,
4734+ end : TFloat ,
47364735 steps : int ,
47374736 dtype : int = FLOAT .dtype ,
47384737 layout : str = "" ,
@@ -4750,7 +4749,6 @@ def aten_linspace(
47504749 if steps == 1 :
47514750 return aten_full (op .Constant (value_ints = [steps ]), start , dtype = dtype )
47524751
4753- # TODO(justinchuby): Simplify the logic knowing start and end are floats
47544752 rg = aten_arange_start (0 , steps , dtype = dtype )
47554753 start = op .Cast (start , to = dtype )
47564754 end = op .Cast (end , to = dtype )
0 commit comments