Skip to content

Commit a833dc6

Browse files
committed
revert the whole family
1 parent 8295a77 commit a833dc6

File tree

1 file changed

+35
-37
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+35
-37
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
TInt,
4343
TReal,
4444
TRealOrUInt8,
45+
TRealUnlessFloat16OrInt8,
4546
TRealUnlessInt16OrInt8,
4647
TTensor,
4748
TTensor2,
@@ -540,7 +541,7 @@ def _integral_to_be_adjusted(dtype: int) -> bool:
540541

541542
@torch_op("aten::arange", trace_only=True)
542543
def aten_arange(
543-
end: float,
544+
end: TRealUnlessFloat16OrInt8,
544545
dtype: int = -1,
545546
layout: str = "",
546547
device: str = "",
@@ -549,10 +550,9 @@ def aten_arange(
549550
"""arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
550551

551552
if dtype == -1 or dtype is None:
552-
if isinstance(end, int):
553-
result = op.Range(0, end, 1)
554-
else:
555-
result = op.Range(0.0, end, 1.0)
553+
zero = op.CastLike(0.0, end)
554+
one = op.CastLike(1.0, end)
555+
result = op.Range(zero, end, one)
556556
elif _range_supported(dtype):
557557
end = op.Cast(end, to=dtype)
558558
zero = op.Cast(0, to=dtype)
@@ -563,7 +563,7 @@ def aten_arange(
563563
# because the input dtype may be e.g. bfloat16 / int8 etc.
564564
# which Range does not support. The output type is ensured because the output
565565
# is casted to the specified dtype.
566-
end = op.Constant(value_float=float(end))
566+
end = op.Cast(end, to=FLOAT.dtype)
567567
zero = op.Constant(value_float=0.0)
568568
one = op.Constant(value_float=1.0)
569569
result = op.Cast(op.Range(zero, end, one), to=dtype)
@@ -573,8 +573,8 @@ def aten_arange(
573573

574574
@torch_op("aten::arange.start", trace_only=True)
575575
def aten_arange_start(
576-
start: TReal,
577-
end: TReal,
576+
start: TRealUnlessFloat16OrInt8,
577+
end: TRealUnlessFloat16OrInt8,
578578
dtype: int = -1,
579579
layout: str = "",
580580
device: str = "",
@@ -604,57 +604,56 @@ def aten_arange_start(
604604

605605

606606
def _adjust_args_for_arange_int_dtype(
607-
start: float,
608-
end: float,
609-
step: float,
610-
) -> Tuple[float, float, float]:
611-
if start < 0:
612-
start = math.ceil(start)
613-
if step < 0:
614-
start = math.floor(start)
607+
start: TRealUnlessFloat16OrInt8,
608+
end: TRealUnlessFloat16OrInt8,
609+
step: TRealUnlessFloat16OrInt8,
610+
) -> Tuple[FLOAT, FLOAT, FLOAT]:
611+
zero = op.Cast(0.0, to=FLOAT.dtype)
612+
start = op.Cast(start, to=FLOAT.dtype)
613+
end = op.Cast(end, to=FLOAT.dtype)
614+
step = op.Cast(step, to=FLOAT.dtype)
615615

616-
return float(start), float(end), float(step)
616+
start = op.Where(op.Less(start, zero), op.Ceil(start), start)
617+
start = op.Where(op.Less(step, zero), op.Floor(start), start)
618+
619+
return (start, end, step)
617620

618621

619622
@torch_op("aten::arange.start_step", trace_only=True)
620623
def aten_arange_start_step(
621-
start: float,
622-
end: float,
623-
step: float = 1.0,
624+
start: TRealUnlessFloat16OrInt8,
625+
end: TRealUnlessFloat16OrInt8,
626+
step: TRealUnlessFloat16OrInt8 = 1.0,
624627
dtype: int = -1,
625628
layout: str = "",
626629
device: str = "",
627630
pin_memory: bool = False,
628631
) -> TensorType:
629632
"""arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
630633

631-
if dtype == -1 or dtype is None:
632-
if isinstance(start, int) and isinstance(end, int):
633-
result = op.Range(start, end, int(step))
634-
else:
635-
start = float(start)
636-
end = float(end)
637-
step = float(step)
638-
result = op.Range(start, end, step)
634+
if dtype == -1:
635+
start = op.Cast(start, to=FLOAT.dtype)
636+
end = op.Cast(end, to=FLOAT.dtype)
637+
result = op.Range(start, end, step)
639638
elif _integral_to_be_adjusted(dtype):
640639
# PyTorch arange op handles these integral types differently from INT64,
641640
# so we have to adjust these arguments accordingly.
642641
# https://github.com/pytorch/pytorch/blob/121cfb60c0817816fcbe2190303b7f6d05c77cf3/torch/_refs/__init__.py#L4794
643642
start, end, step = _adjust_args_for_arange_int_dtype(start, end, step)
644643
result = op.Cast(op.Range(start, end, step), to=dtype)
645644
elif dtype == INT64.dtype:
646-
end = int(end)
647-
start = int(start)
648-
step = int(step)
645+
end = op.Cast(end, to=dtype)
646+
start = op.Cast(start, to=dtype)
647+
step = op.Cast(step, to=dtype)
649648
result = op.Range(start, end, step)
650649
else:
651650
# Cast input to float if dtype is not supported by Range,
652651
# because the input dtype may be e.g. bfloat16,
653652
# which Range does not support. The output type is ensured because the output
654653
# is casted to the specified dtype.
655-
end = float(end)
656-
start = float(start)
657-
step = float(step)
654+
end = op.Cast(end, to=FLOAT.dtype)
655+
start = op.Cast(start, to=FLOAT.dtype)
656+
step = op.Cast(step, to=FLOAT.dtype)
658657
result = op.Cast(op.Range(start, end, step), to=dtype)
659658

660659
return result
@@ -4731,8 +4730,8 @@ def aten_linear_backward(
47314730

47324731
@torch_op("aten::linspace", trace_only=True)
47334732
def aten_linspace(
4734-
start: float,
4735-
end: float,
4733+
start: TFloat,
4734+
end: TFloat,
47364735
steps: int,
47374736
dtype: int = FLOAT.dtype,
47384737
layout: str = "",
@@ -4750,7 +4749,6 @@ def aten_linspace(
47504749
if steps == 1:
47514750
return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
47524751

4753-
# TODO(justinchuby): Simplify the logic knowing start and end are floats
47544752
rg = aten_arange_start(0, steps, dtype=dtype)
47554753
start = op.Cast(start, to=dtype)
47564754
end = op.Cast(end, to=dtype)

0 commit comments

Comments
 (0)