     TInt,
     TReal,
     TRealOrUInt8,
+    TRealUnlessFloat16OrInt8,
     TRealUnlessInt16OrInt8,
     TTensor,
     TTensor2,
@@ -540,7 +541,7 @@ def _integral_to_be_adjusted(dtype: int) -> bool:

 @torch_op("aten::arange", trace_only=True)
 def aten_arange(
-    end: float,
+    end: TRealUnlessFloat16OrInt8,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
@@ -549,10 +550,9 @@ def aten_arange(
     """arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

     if dtype == -1 or dtype is None:
-        if isinstance(end, int):
-            result = op.Range(0, end, 1)
-        else:
-            result = op.Range(0.0, end, 1.0)
+        zero = op.CastLike(0.0, end)
+        one = op.CastLike(1.0, end)
+        result = op.Range(zero, end, one)
     elif _range_supported(dtype):
         end = op.Cast(end, to=dtype)
         zero = op.Cast(0, to=dtype)
@@ -563,7 +563,7 @@ def aten_arange(
         # because the input dtype may be e.g. bfloat16 / int8 etc.
         # which Range does not support. The output type is ensured because the output
         # is casted to the specified dtype.
-        end = op.Constant(value_float=float(end))
+        end = op.Cast(end, to=FLOAT.dtype)
         zero = op.Constant(value_float=0.0)
         one = op.Constant(value_float=1.0)
         result = op.Cast(op.Range(zero, end, one), to=dtype)
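
A minimal NumPy sketch, not part of the diff, of what the new default-dtype branch computes (`arange_like` is a hypothetical name): 0 and 1 are derived in `end`'s own dtype, the CastLike equivalent, so Range receives a single consistent element type whether `end` is an integer or a floating-point tensor.

import numpy as np

def arange_like(end):
    # Mirror of the CastLike branch above: build 0 and 1 in end's dtype
    # so all three Range inputs share one element type.
    end = np.asarray(end)
    zero = np.zeros((), dtype=end.dtype)
    one = np.ones((), dtype=end.dtype)
    return np.arange(zero, end, one)

print(arange_like(np.float32(5.0)).dtype)  # float32
print(arange_like(np.int64(5)).dtype)      # int64
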
@@ -573,8 +573,8 @@ def aten_arange(

 @torch_op("aten::arange.start", trace_only=True)
 def aten_arange_start(
-    start: float,
-    end: float,
+    start: TRealUnlessFloat16OrInt8,
+    end: TRealUnlessFloat16OrInt8,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
@@ -583,12 +583,8 @@ def aten_arange_start(
     """arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

     if dtype == -1 or dtype is None:
-        if isinstance(start, int) and isinstance(end, int):
-            result = op.Range(start, end, 1)
-        else:
-            start = float(start)
-            end = float(end)
-            result = op.Range(start, end, 1.0)
+        one = op.CastLike(1.0, end)
+        result = op.Range(start, end, one)
     elif _range_supported(dtype):
         end = op.Cast(end, to=dtype)
         start = op.Cast(start, to=dtype)
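
`_range_supported(dtype)` is referenced in these branches but defined outside the diff. A plausible sketch, assuming it simply whitelists the element types ONNX Range accepts (float, double, int16, int32, int64), which is why bfloat16/float16/int8 inputs fall through to the Cast-to-FLOAT path:

from onnxscript import DOUBLE, FLOAT, INT16, INT32, INT64

# Hypothetical sketch of _range_supported; the real helper is not shown here.
# ONNX Range is only defined for these element types.
_RANGE_SUPPORTED_DTYPES = frozenset(
    {DOUBLE.dtype, FLOAT.dtype, INT16.dtype, INT32.dtype, INT64.dtype}
)


def _range_supported(dtype: int) -> bool:
    return dtype in _RANGE_SUPPORTED_DTYPES
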
@@ -599,46 +595,78 @@ def aten_arange_start(
         # because the input dtype may be e.g. bfloat16 / int8 etc.
         # which Range does not support. The output type is ensured because the output
         # is casted to the specified dtype.
-        end = op.Constant(value_float=float(end))
-        start = op.Constant(value_float=float(start))
+        end = op.Cast(end, to=FLOAT.dtype)
+        start = op.Cast(start, to=FLOAT.dtype)
         one = op.Constant(value_float=1.0)
         result = op.Cast(op.Range(start, end, one), to=dtype)

     return result


 def _adjust_args_for_arange_int_dtype(
-    start: float,
-    end: float,
-    step: float,
-) -> Tuple[float, float, float]:
-    if start < 0:
-        start = math.ceil(start)
-    if step < 0:
-        start = math.floor(start)
+    start: TRealUnlessFloat16OrInt8,
+    end: TRealUnlessFloat16OrInt8,
+    step: TRealUnlessFloat16OrInt8,
+) -> Tuple[FLOAT, FLOAT, FLOAT]:
+    zero = op.Cast(0.0, to=FLOAT.dtype)
+    start = op.Cast(start, to=FLOAT.dtype)
+    end = op.Cast(end, to=FLOAT.dtype)
+    step = op.Cast(step, to=FLOAT.dtype)

-    return float(start), float(end), float(step)
+    start = op.Where(op.Less(start, zero), op.Ceil(start), start)
+    start = op.Where(op.Less(step, zero), op.Floor(start), start)
+
+    return (start, end, step)


 @torch_op("aten::arange.start_step", trace_only=True)
 def aten_arange_start_step(
-    start: float,
-    end: float,
-    step: float = 1.0,
+    start: TRealUnlessFloat16OrInt8,
+    end: TRealUnlessFloat16OrInt8,
+    step: TRealUnlessFloat16OrInt8 = 1.0,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
 ) -> TensorType:
     """arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

-    if dtype == -1 or dtype is None:
-        if isinstance(start, int) and isinstance(end, int):
-            result = op.Range(start, end, int(step))
+    if dtype == -1:
+        # TODO: Because this is a trace_only function, the inputs are not promoted to
+        # Tensor until it hits ONNX ops. However, if it's dynamic, it should be
+        # Tensor at this point.
+        # https://github.com/microsoft/onnxscript/issues/1914
+        if isinstance(start, (int, float)):
+            start_is_int = isinstance(start, int)
         else:
-            start = float(start)
-            end = float(end)
-            step = float(step)
+            start_is_int = start.dtype in {
+                INT16.dtype,
+                INT32.dtype,
+                INT64.dtype,
+            }
+        if isinstance(end, (int, float)):
+            end_is_int = isinstance(end, int)
+        else:
+            end_is_int = end.dtype in {
+                INT16.dtype,
+                INT32.dtype,
+                INT64.dtype,
+            }
+        if isinstance(step, (int, float)):
+            step_is_int = isinstance(step, int)
+        else:
+            step_is_int = step.dtype in {
+                INT16.dtype,
+                INT32.dtype,
+                INT64.dtype,
+            }
+        if start_is_int and end_is_int and step_is_int:
+            result = op.Range(start, end, step)
+        else:
+            # to float
+            start = op.Cast(start, to=FLOAT.dtype)
+            end = op.Cast(end, to=FLOAT.dtype)
+            step = op.Cast(step, to=FLOAT.dtype)
             result = op.Range(start, end, step)
     elif _integral_to_be_adjusted(dtype):
         # PyTorch arange op handles these integral types differently from INT64,
@@ -647,18 +675,18 @@ def aten_arange_start_step(
         start, end, step = _adjust_args_for_arange_int_dtype(start, end, step)
         result = op.Cast(op.Range(start, end, step), to=dtype)
     elif dtype == INT64.dtype:
-        end = int(end)
-        start = int(start)
-        step = int(step)
+        end = op.Cast(end, to=dtype)
+        start = op.Cast(start, to=dtype)
+        step = op.Cast(step, to=dtype)
         result = op.Range(start, end, step)
     else:
         # Cast input to float if dtype is not supported by Range,
         # because the input dtype may be e.g. bfloat16,
         # which Range does not support. The output type is ensured because the output
         # is casted to the specified dtype.
-        end = float(end)
-        start = float(start)
-        step = float(step)
+        end = op.Cast(end, to=FLOAT.dtype)
+        start = op.Cast(start, to=FLOAT.dtype)
+        step = op.Cast(step, to=FLOAT.dtype)
         result = op.Cast(op.Range(start, end, step), to=dtype)

     return result
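
A NumPy mirror, assuming plain scalar inputs (`adjust_start` is a hypothetical name), of the Where/Less/Ceil/Floor graph that replaces the old math.ceil/math.floor logic in `_adjust_args_for_arange_int_dtype`: a negative start is rounded toward zero, then floored instead when the step is negative, before everything is handed to Range in float and cast back to the integral dtype.

import numpy as np

def adjust_start(start, end, step):
    # Same dataflow as the ONNX graph above, on NumPy scalars.
    start, end, step = np.float32(start), np.float32(end), np.float32(step)
    start = np.where(start < 0, np.ceil(start), start)
    start = np.where(step < 0, np.floor(start), start)
    return start, end, step

print(adjust_start(-2.5, 5, 1))   # start becomes -2.0 (ceil toward zero)
print(adjust_start(2.5, -5, -1))  # start becomes 2.0 (floored because step < 0)
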
@@ -4735,8 +4763,8 @@ def aten_linear_backward(

 @torch_op("aten::linspace", trace_only=True)
 def aten_linspace(
-    start: float,
-    end: float,
+    start: TFloat,
+    end: TFloat,
     steps: int,
     dtype: int = FLOAT.dtype,
     layout: str = "",
@@ -4754,7 +4782,6 @@ def aten_linspace(
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)

-    # TODO(justinchuby): Simplify the logic knowing start and end are floats
     rg = aten_arange_start(0, steps, dtype=dtype)
     start = op.Cast(start, to=dtype)
     end = op.Cast(end, to=dtype)
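
For the `steps == 1` shortcut kept as context above, a quick NumPy check of the expected result: a one-step linspace is just a length-1 tensor holding `start`, which is what the `aten_full(...)` call produces.

import numpy as np

print(np.linspace(2.0, 10.0, 1))  # [2.]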