Commit 6b0b324

Merge branch 'main' into titaiwang/make_trace_function_promote_constant

2 parents ed7636f + 2b60939

File tree

10 files changed: +229 -71 lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 70 additions & 43 deletions
@@ -42,6 +42,7 @@
     TInt,
     TReal,
     TRealOrUInt8,
+    TRealUnlessFloat16OrInt8,
     TRealUnlessInt16OrInt8,
     TTensor,
     TTensor2,
@@ -540,7 +541,7 @@ def _integral_to_be_adjusted(dtype: int) -> bool:
 
 @torch_op("aten::arange", trace_only=True)
 def aten_arange(
-    end: float,
+    end: TRealUnlessFloat16OrInt8,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
@@ -549,10 +550,9 @@ def aten_arange(
     """arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
     if dtype == -1 or dtype is None:
-        if isinstance(end, int):
-            result = op.Range(0, end, 1)
-        else:
-            result = op.Range(0.0, end, 1.0)
+        zero = op.CastLike(0.0, end)
+        one = op.CastLike(1.0, end)
+        result = op.Range(zero, end, one)
     elif _range_supported(dtype):
         end = op.Cast(end, to=dtype)
         zero = op.Cast(0, to=dtype)
@@ -563,7 +563,7 @@ def aten_arange(
         # because the input dtype may be e.g. bfloat16 / int8 etc.
         # which Range does not support. The output type is ensured because the output
         # is casted to the specified dtype.
-        end = op.Constant(value_float=float(end))
+        end = op.Cast(end, to=FLOAT.dtype)
         zero = op.Constant(value_float=0.0)
         one = op.Constant(value_float=1.0)
         result = op.Cast(op.Range(zero, end, one), to=dtype)
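Note: the rewrite above replaces the trace-time isinstance check with graph-level CastLike nodes, so the 0 and 1 constants take on whatever dtype `end` carries. A minimal numpy sketch of the semantics the resulting Range subgraph computes (the helper name is illustrative, not part of the commit):

    import numpy as np

    def arange_like_semantics(end: np.ndarray) -> np.ndarray:
        """Numpy model of op.Range(op.CastLike(0.0, end), end, op.CastLike(1.0, end))."""
        zero = np.zeros((), dtype=end.dtype)  # op.CastLike(0.0, end)
        one = np.ones((), dtype=end.dtype)    # op.CastLike(1.0, end)
        return np.arange(zero, end, one)      # op.Range(zero, end, one)

    print(arange_like_semantics(np.asarray(5)))    # integer end -> [0 1 2 3 4]
    print(arange_like_semantics(np.asarray(2.5)))  # float end   -> [0. 1. 2.]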
@@ -573,8 +573,8 @@
 
 @torch_op("aten::arange.start", trace_only=True)
 def aten_arange_start(
-    start: float,
-    end: float,
+    start: TRealUnlessFloat16OrInt8,
+    end: TRealUnlessFloat16OrInt8,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
@@ -583,12 +583,8 @@ def aten_arange_start(
     """arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
     if dtype == -1 or dtype is None:
-        if isinstance(start, int) and isinstance(end, int):
-            result = op.Range(start, end, 1)
-        else:
-            start = float(start)
-            end = float(end)
-            result = op.Range(start, end, 1.0)
+        one = op.CastLike(1.0, end)
+        result = op.Range(start, end, one)
     elif _range_supported(dtype):
         end = op.Cast(end, to=dtype)
         start = op.Cast(start, to=dtype)
@@ -599,46 +595,78 @@ def aten_arange_start(
         # because the input dtype may be e.g. bfloat16 / int8 etc.
         # which Range does not support. The output type is ensured because the output
         # is casted to the specified dtype.
-        end = op.Constant(value_float=float(end))
-        start = op.Constant(value_float=float(start))
+        end = op.Cast(end, to=FLOAT.dtype)
+        start = op.Cast(start, to=FLOAT.dtype)
         one = op.Constant(value_float=1.0)
         result = op.Cast(op.Range(start, end, one), to=dtype)
 
     return result
 
 
 def _adjust_args_for_arange_int_dtype(
-    start: float,
-    end: float,
-    step: float,
-) -> Tuple[float, float, float]:
-    if start < 0:
-        start = math.ceil(start)
-    if step < 0:
-        start = math.floor(start)
+    start: TRealUnlessFloat16OrInt8,
+    end: TRealUnlessFloat16OrInt8,
+    step: TRealUnlessFloat16OrInt8,
+) -> Tuple[FLOAT, FLOAT, FLOAT]:
+    zero = op.Cast(0.0, to=FLOAT.dtype)
+    start = op.Cast(start, to=FLOAT.dtype)
+    end = op.Cast(end, to=FLOAT.dtype)
+    step = op.Cast(step, to=FLOAT.dtype)
 
-    return float(start), float(end), float(step)
+    start = op.Where(op.Less(start, zero), op.Ceil(start), start)
+    start = op.Where(op.Less(step, zero), op.Floor(start), start)
+
+    return (start, end, step)
 
 
 @torch_op("aten::arange.start_step", trace_only=True)
 def aten_arange_start_step(
-    start: float,
-    end: float,
-    step: float = 1.0,
+    start: TRealUnlessFloat16OrInt8,
+    end: TRealUnlessFloat16OrInt8,
+    step: TRealUnlessFloat16OrInt8 = 1.0,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
 ) -> TensorType:
     """arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    if dtype == -1 or dtype is None:
-        if isinstance(start, int) and isinstance(end, int):
-            result = op.Range(start, end, int(step))
+    if dtype == -1:
+        # TODO: Because this is a trace_only function, the inputs are not promoted to
+        # Tensor until it hits ONNX ops. However, if it's dynamic, it should be
+        # Tensor at this point.
+        # https://github.com/microsoft/onnxscript/issues/1914
+        if isinstance(start, (int, float)):
+            start_is_int = isinstance(start, int)
         else:
-            start = float(start)
-            end = float(end)
-            step = float(step)
+            start_is_int = start.dtype in {
+                INT16.dtype,
+                INT32.dtype,
+                INT64.dtype,
+            }
+        if isinstance(end, (int, float)):
+            end_is_int = isinstance(end, int)
+        else:
+            end_is_int = end.dtype in {
+                INT16.dtype,
+                INT32.dtype,
+                INT64.dtype,
+            }
+        if isinstance(step, (int, float)):
+            step_is_int = isinstance(step, int)
+        else:
+            step_is_int = step.dtype in {
+                INT16.dtype,
+                INT32.dtype,
+                INT64.dtype,
+            }
+        if start_is_int and end_is_int and step_is_int:
+            result = op.Range(start, end, step)
+        else:
+            # to float
+            start = op.Cast(start, to=FLOAT.dtype)
+            end = op.Cast(end, to=FLOAT.dtype)
+            step = op.Cast(step, to=FLOAT.dtype)
             result = op.Range(start, end, step)
     elif _integral_to_be_adjusted(dtype):
         # PyTorch arange op handles these integral types differently from INT64,
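Note: the three start_is_int / end_is_int / step_is_int blocks added above all apply one rule: a trace-time argument is either a plain Python scalar or a graph value carrying an ONNX dtype. A hypothetical helper that factors out that rule (not in the commit; INT16/INT32/INT64 are the onnxscript tensor types core.py already imports):

    from onnxscript import INT16, INT32, INT64

    _INT_DTYPES = frozenset({INT16.dtype, INT32.dtype, INT64.dtype})

    def _arg_is_int(value) -> bool:
        """True when a trace-time arange argument is integral."""
        if isinstance(value, (int, float)):
            # Python constant, not yet promoted to a tensor.
            return isinstance(value, int)
        # Graph value: inspect its ONNX dtype.
        return value.dtype in _INT_DTYPES

With it, the branch would reduce to `if _arg_is_int(start) and _arg_is_int(end) and _arg_is_int(step)`.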
@@ -647,18 +675,18 @@ def aten_arange_start_step(
         start, end, step = _adjust_args_for_arange_int_dtype(start, end, step)
         result = op.Cast(op.Range(start, end, step), to=dtype)
     elif dtype == INT64.dtype:
-        end = int(end)
-        start = int(start)
-        step = int(step)
+        end = op.Cast(end, to=dtype)
+        start = op.Cast(start, to=dtype)
+        step = op.Cast(step, to=dtype)
         result = op.Range(start, end, step)
     else:
         # Cast input to float if dtype is not supported by Range,
         # because the input dtype may be e.g. bfloat16,
         # which Range does not support. The output type is ensured because the output
         # is casted to the specified dtype.
-        end = float(end)
-        start = float(start)
-        step = float(step)
+        end = op.Cast(end, to=FLOAT.dtype)
+        start = op.Cast(start, to=FLOAT.dtype)
+        step = op.Cast(step, to=FLOAT.dtype)
         result = op.Cast(op.Range(start, end, step), to=dtype)
 
     return result
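Note: this hunk is the heart of the branch name (make_trace_function_promote_constant): int(end) and float(end) only work while end is a Python scalar and raise TypeError once end is a dynamic graph value, whereas an op.Cast node handles both, since trace-only functions promote Python constants to tensors when they reach an ONNX op. A toy demonstration of the failure mode, with GraphValue as a made-up stand-in for a traced value:

    class GraphValue:
        """Stand-in for a traced value flowing through the exporter."""
        dtype = None  # defines no __int__ or __float__

    try:
        int(GraphValue())
    except TypeError:
        print("int() rejects graph values; a Cast node must be emitted instead")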
@@ -4735,8 +4763,8 @@ def aten_linear_backward(
 
 @torch_op("aten::linspace", trace_only=True)
 def aten_linspace(
-    start: float,
-    end: float,
+    start: TFloat,
+    end: TFloat,
     steps: int,
     dtype: int = FLOAT.dtype,
     layout: str = "",
@@ -4754,7 +4782,6 @@ def aten_linspace(
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
 
-    # TODO(justinchuby): Simplify the logic knowing start and end are floats
    rg = aten_arange_start(0, steps, dtype=dtype)
    start = op.Cast(start, to=dtype)
    end = op.Cast(end, to=dtype)
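Note: the deleted TODO is now moot, since start and end arrive as TFloat. The page truncates the rest of the body; for context, a numpy sketch of the standard affine construction that aten_linspace builds on top of aten_arange_start (illustrative, assuming steps > 1):

    import numpy as np

    def linspace_sketch(start: float, end: float, steps: int) -> np.ndarray:
        assert steps > 1  # steps == 1 is handled by the aten_full shortcut above
        rg = np.arange(0, steps, dtype=np.float32)       # aten_arange_start(0, steps, dtype)
        return start + rg * (end - start) / (steps - 1)  # affine map onto [start, end]

    print(linspace_sketch(0.0, 1.0, 5))  # [0.   0.25 0.5  0.75 1.  ]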

onnxscript/ir/_core.py

Lines changed: 21 additions & 3 deletions
@@ -70,6 +70,7 @@
         _enums.DataType.FLOAT8E5M2FNUZ,
         _enums.DataType.INT4,
         _enums.DataType.UINT4,
+        _enums.DataType.FLOAT4E2M1,
     )
 )
 
@@ -182,7 +183,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
     When the dtype is not one of the numpy native dtypes, the value needs need to be:
 
     - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
-    - ``uint8`` for uint4.
+    - ``uint8`` for uint4 or float4.
     - ``uint8`` for 8-bit data types.
     - ``uint16`` for bfloat16
 
@@ -213,6 +214,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
             raise TypeError(
                 f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
             )
+        if dtype == _enums.DataType.FLOAT4E2M1:
+            if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn):
+                raise TypeError(
+                    f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
+                )
         return
 
     try:
@@ -256,6 +262,8 @@ def _maybe_view_np_array_with_ml_dtypes(
         return array.view(ml_dtypes.int4)
     if dtype == _enums.DataType.UINT4:
         return array.view(ml_dtypes.uint4)
+    if dtype == _enums.DataType.FLOAT4E2M1:
+        return array.view(ml_dtypes.float4_e2m1fn)
     return array
 
 
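Note: the new branch reinterprets an unpacked uint8 buffer (one 4-bit code per byte) as ml_dtypes.float4_e2m1fn without copying. A small standalone sketch of that view, assuming an ml_dtypes release that ships float4_e2m1fn:

    import ml_dtypes
    import numpy as np

    codes = np.array([0b0000, 0b0001, 0b0010], dtype=np.uint8)  # one E2M1 code per byte
    values = codes.view(ml_dtypes.float4_e2m1fn)                # zero-copy reinterpretation
    print(values)  # 0, 0.5, 1 under the E2M1 encoding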
@@ -431,7 +439,11 @@ def tobytes(self) -> bytes:
         """
         # TODO(justinchuby): Support DLPack
         array = self.numpy()
-        if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
+        if self.dtype in {
+            _enums.DataType.INT4,
+            _enums.DataType.UINT4,
+            _enums.DataType.FLOAT4E2M1,
+        }:
             # Pack the array into int4
             array = _type_casting.pack_int4(array)
         else:
@@ -609,7 +621,11 @@ def _load(self):
             )
         # Handle the byte order correctly by always using little endian
         dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
-        if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
+        if self.dtype in {
+            _enums.DataType.INT4,
+            _enums.DataType.UINT4,
+            _enums.DataType.FLOAT4E2M1,
+        }:
             # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
             dt = np.dtype(np.uint8).newbyteorder("<")
             count = self.size // 2 + self.size % 2
@@ -622,6 +638,8 @@
             self._array = _type_casting.unpack_int4(self._array, shape)
         elif self.dtype == _enums.DataType.UINT4:
             self._array = _type_casting.unpack_uint4(self._array, shape)
+        elif self.dtype == _enums.DataType.FLOAT4E2M1:
+            self._array = _type_casting.unpack_float4e2m1(self._array, shape)
         else:
             self._array = self._array.reshape(shape)
 
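Note: tobytes() and _load() now treat FLOAT4E2M1 exactly like the 4-bit integer types: two elements per byte, low nibble first, with odd-length data padded by a zero nibble. A minimal sketch of that layout (not onnxscript's _type_casting module) that reproduces the b"\x10r\x0f" expectation in the tests below:

    import numpy as np

    def pack_4bit(values: np.ndarray) -> bytes:
        """Pack 4-bit codes two per byte, low nibble first (sketch of pack_int4)."""
        flat = values.astype(np.uint8).ravel() & 0x0F
        if flat.size % 2 == 1:
            flat = np.append(flat, np.uint8(0))  # pad odd-sized input
        return bytes(flat[0::2] | (flat[1::2] << 4))

    assert pack_4bit(np.array([0, 1, 2, 7, 15])) == b"\x10r\x0f"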
onnxscript/ir/_core_test.py

Lines changed: 32 additions & 4 deletions
@@ -55,6 +55,7 @@ def test_init_requires_type_when_value_is_not_np_array(self):
             ("int4", np.int8, ir.DataType.INT4),
             ("int4_uint8", np.uint8, ir.DataType.INT4),
             ("uint4", np.uint8, ir.DataType.UINT4),
+            ("float4e2m1", np.uint8, ir.DataType.FLOAT4E2M1),
         ]
     )
     def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: ir.DataType):
@@ -131,34 +132,48 @@ def test_tobytes(self):
         tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)
         self.assertEqual(tensor.tobytes(), array.tobytes())
 
-    def test_tobtyes_returns_packed_data_for_int4(self):
+    def test_tobytes_returns_packed_data_for_int4(self):
         array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8)
         # Test odd sized array
         assert len(array) % 2 == 1
         tensor = _core.Tensor(array, dtype=ir.DataType.INT4)
         self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01")
 
-    def test_tobtyes_returns_packed_data_for_int4_ml_dtypes(self):
+    def test_tobytes_returns_packed_data_for_int4_ml_dtypes(self):
         array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=ml_dtypes.int4)
         # Test odd sized array
         assert len(array) % 2 == 1
         tensor = _core.Tensor(array, dtype=ir.DataType.INT4)
         self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01")
 
-    def test_tobtyes_returns_packed_data_for_uint4(self):
+    def test_tobytes_returns_packed_data_for_uint4(self):
         array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
         # Test odd sized array
         assert len(array) % 2 == 1
         tensor = _core.Tensor(array, dtype=ir.DataType.UINT4)
         self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
 
-    def test_tobtyes_returns_packed_data_for_uint4_ml_dtypes(self):
+    def test_tobytes_returns_packed_data_for_uint4_ml_dtypes(self):
         array = np.array([0, 1, 2, 7, 15], dtype=ml_dtypes.uint4)
         # Test odd sized array
         assert len(array) % 2 == 1
         tensor = _core.Tensor(array, dtype=ir.DataType.UINT4)
         self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
 
+    def test_tobytes_returns_packed_data_for_float4e2m1(self):
+        array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
+        # Test odd sized array
+        assert len(array) % 2 == 1
+        tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1)
+        self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
+
+    def test_tobytes_returns_packed_data_for_float4e2m1_ml_dtypes(self):
+        array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
+        # Test odd sized array
+        assert len(array) % 2 == 1
+        tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1)
+        self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
+
     def test_metadata(self):
         array = np.random.rand(1, 2).astype(np.float32)
         tensor = _core.Tensor(array)
@@ -444,6 +459,19 @@ def test_external_tensor_complex(self, _: str, np_dtype: np.dtype):
         # about permission errors
         del tensor
 
+    def test_external_tensor_float4e2m1(self):
+        expected_array = np.array([0, 1, 2, 7, 15]).view(ml_dtypes.float4_e2m1fn)
+        tensor_proto = ir.serde.serialize_tensor(
+            ir.Tensor(expected_array, dtype=ir.DataType.FLOAT4E2M1)
+        )
+        with tempfile.TemporaryDirectory() as temp_dir:
+            _to_external_tensor(tensor_proto, temp_dir, "tensor.bin")
+            tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir)
+            np.testing.assert_array_equal(tensor.numpy(), expected_array)
+        # Close the mmap file by deleting the reference to tensor so Windows doesn't complain
+        # about permission errors
+        del tensor
+
     def test_external_tensor_empty_tensor(self):
         expected_array = np.array([], dtype=np.float32)
         tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array))
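Note: together with the _core.py changes, these tests cover the full round trip. A usage example distilled from them (FLOAT4E2M1 requires an ml_dtypes build that ships float4_e2m1fn):

    import numpy as np
    from onnxscript import ir

    array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)  # unpacked 4-bit codes
    tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT4E2M1)
    print(tensor.tobytes())  # b'\x10r\x0f', two codes packed per byte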

onnxscript/ir/_enums.py

Lines changed: 9 additions & 0 deletions
@@ -64,6 +64,7 @@ class DataType(enum.IntEnum):
     FLOAT8E5M2FNUZ = 20
     UINT4 = 21
     INT4 = 22
+    FLOAT4E2M1 = 23
 
     @classmethod
     def from_numpy(cls, dtype: np.dtype) -> DataType:
@@ -121,6 +122,7 @@ def __str__(self) -> str:
     DataType.FLOAT8E5M2FNUZ: 1,
     DataType.UINT4: 0.5,
     DataType.INT4: 0.5,
+    DataType.FLOAT4E2M1: 0.5,
 }
 
 
@@ -150,5 +152,12 @@ def __str__(self) -> str:
     np.dtype(ml_dtypes.uint4): DataType.UINT4,
 }
 
+# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
+_NP_TYPE_TO_DATA_TYPE.update(
+    {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1}
+    if hasattr(ml_dtypes, "float4_e2m1fn")
+    else {}
+)
+
 # ONNX DataType to Numpy dtype.
 _DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
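Note: the update() is guarded with hasattr because float4_e2m1fn only exists in newer ml_dtypes releases, hence the TODO about raising the minimum requirement to ml_dtypes >= 0.5. The mapping is what DataType.from_numpy consults, e.g.:

    import ml_dtypes
    import numpy as np
    from onnxscript import ir

    if hasattr(ml_dtypes, "float4_e2m1fn"):  # only on ml_dtypes >= 0.5
        dt = ir.DataType.from_numpy(np.dtype(ml_dtypes.float4_e2m1fn))
        assert dt == ir.DataType.FLOAT4E2M1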
