Skip to content

Commit 8295a77

Browse files
authored
Merge branch 'main' into titaiwang/fix_arange
2 parents 0636533 + 3016daa commit 8295a77

File tree

8 files changed

+84
-9
lines changed

8 files changed

+84
-9
lines changed

onnxscript/ir/_core.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
_enums.DataType.FLOAT8E5M2FNUZ,
7171
_enums.DataType.INT4,
7272
_enums.DataType.UINT4,
73+
_enums.DataType.FLOAT4E2M1,
7374
)
7475
)
7576

@@ -182,7 +183,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
182183
When the dtype is not one of the numpy native dtypes, the value needs to be:
183184
184185
- ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
185-
- ``uint8`` for uint4.
186+
- ``uint8`` for uint4 or float4.
186187
- ``uint8`` for 8-bit data types.
187188
- ``uint16`` for bfloat16
188189
@@ -213,6 +214,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
213214
raise TypeError(
214215
f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
215216
)
217+
if dtype == _enums.DataType.FLOAT4E2M1:
218+
if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn):
219+
raise TypeError(
220+
f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
221+
)
216222
return
217223

218224
try:
@@ -256,6 +262,8 @@ def _maybe_view_np_array_with_ml_dtypes(
256262
return array.view(ml_dtypes.int4)
257263
if dtype == _enums.DataType.UINT4:
258264
return array.view(ml_dtypes.uint4)
265+
if dtype == _enums.DataType.FLOAT4E2M1:
266+
return array.view(ml_dtypes.float4_e2m1fn)
259267
return array
260268

261269

@@ -431,7 +439,11 @@ def tobytes(self) -> bytes:
431439
"""
432440
# TODO(justinchuby): Support DLPack
433441
array = self.numpy()
434-
if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
442+
if self.dtype in {
443+
_enums.DataType.INT4,
444+
_enums.DataType.UINT4,
445+
_enums.DataType.FLOAT4E2M1,
446+
}:
435447
# Pack the 4-bit values (int4, uint4 or float4e2m1) two per byte
436448
array = _type_casting.pack_int4(array)
437449
else:
@@ -609,7 +621,11 @@ def _load(self):
609621
)
610622
# Handle the byte order correctly by always using little endian
611623
dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
612-
if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
624+
if self.dtype in {
625+
_enums.DataType.INT4,
626+
_enums.DataType.UINT4,
627+
_enums.DataType.FLOAT4E2M1,
628+
}:
613629
# Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
614630
dt = np.dtype(np.uint8).newbyteorder("<")
615631
count = self.size // 2 + self.size % 2
@@ -622,6 +638,8 @@ def _load(self):
622638
self._array = _type_casting.unpack_int4(self._array, shape)
623639
elif self.dtype == _enums.DataType.UINT4:
624640
self._array = _type_casting.unpack_uint4(self._array, shape)
641+
elif self.dtype == _enums.DataType.FLOAT4E2M1:
642+
self._array = _type_casting.unpack_float4e2m1(self._array, shape)
625643
else:
626644
self._array = self._array.reshape(shape)
627645

onnxscript/ir/_core_test.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_init_requires_type_when_value_is_not_np_array(self):
5555
("int4", np.int8, ir.DataType.INT4),
5656
("int4_uint8", np.uint8, ir.DataType.INT4),
5757
("uint4", np.uint8, ir.DataType.UINT4),
58+
("float4e2m1", np.uint8, ir.DataType.FLOAT4E2M1),
5859
]
5960
)
6061
def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: ir.DataType):
@@ -131,34 +132,48 @@ def test_tobytes(self):
131132
tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)
132133
self.assertEqual(tensor.tobytes(), array.tobytes())
133134

134-
def test_tobtyes_returns_packed_data_for_int4(self):
135+
def test_tobytes_returns_packed_data_for_int4(self):
135136
array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8)
136137
# Test odd sized array
137138
assert len(array) % 2 == 1
138139
tensor = _core.Tensor(array, dtype=ir.DataType.INT4)
139140
self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01")
140141

141-
def test_tobtyes_returns_packed_data_for_int4_ml_dtypes(self):
142+
def test_tobytes_returns_packed_data_for_int4_ml_dtypes(self):
142143
array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=ml_dtypes.int4)
143144
# Test odd sized array
144145
assert len(array) % 2 == 1
145146
tensor = _core.Tensor(array, dtype=ir.DataType.INT4)
146147
self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01")
147148

148-
def test_tobtyes_returns_packed_data_for_uint4(self):
149+
def test_tobytes_returns_packed_data_for_uint4(self):
149150
array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
150151
# Test odd sized array
151152
assert len(array) % 2 == 1
152153
tensor = _core.Tensor(array, dtype=ir.DataType.UINT4)
153154
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
154155

155-
def test_tobtyes_returns_packed_data_for_uint4_ml_dtypes(self):
156+
def test_tobytes_returns_packed_data_for_uint4_ml_dtypes(self):
156157
array = np.array([0, 1, 2, 7, 15], dtype=ml_dtypes.uint4)
157158
# Test odd sized array
158159
assert len(array) % 2 == 1
159160
tensor = _core.Tensor(array, dtype=ir.DataType.UINT4)
160161
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
161162

163+
def test_tobytes_returns_packed_data_for_float4e2m1(self):
164+
array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
165+
# Test odd sized array
166+
assert len(array) % 2 == 1
167+
tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1)
168+
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
169+
170+
def test_tobytes_returns_packed_data_for_float4e2m1_ml_dtypes(self):
171+
array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
172+
# Test odd sized array
173+
assert len(array) % 2 == 1
174+
tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1)
175+
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
176+
162177
def test_metadata(self):
163178
array = np.random.rand(1, 2).astype(np.float32)
164179
tensor = _core.Tensor(array)
@@ -444,6 +459,19 @@ def test_external_tensor_complex(self, _: str, np_dtype: np.dtype):
444459
# about permission errors
445460
del tensor
446461

462+
def test_external_tensor_float4e2m1(self):
463+
expected_array = np.array([0, 1, 2, 7, 15]).view(ml_dtypes.float4_e2m1fn)
464+
tensor_proto = ir.serde.serialize_tensor(
465+
ir.Tensor(expected_array, dtype=ir.DataType.FLOAT4E2M1)
466+
)
467+
with tempfile.TemporaryDirectory() as temp_dir:
468+
_to_external_tensor(tensor_proto, temp_dir, "tensor.bin")
469+
tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir)
470+
np.testing.assert_array_equal(tensor.numpy(), expected_array)
471+
# Close the mmap file by deleting the reference to tensor so Windows doesn't complain
472+
# about permission errors
473+
del tensor
474+
447475
def test_external_tensor_empty_tensor(self):
448476
expected_array = np.array([], dtype=np.float32)
449477
tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array))

onnxscript/ir/_enums.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class DataType(enum.IntEnum):
6464
FLOAT8E5M2FNUZ = 20
6565
UINT4 = 21
6666
INT4 = 22
67+
FLOAT4E2M1 = 23
6768

6869
@classmethod
6970
def from_numpy(cls, dtype: np.dtype) -> DataType:
@@ -121,6 +122,7 @@ def __str__(self) -> str:
121122
DataType.FLOAT8E5M2FNUZ: 1,
122123
DataType.UINT4: 0.5,
123124
DataType.INT4: 0.5,
125+
DataType.FLOAT4E2M1: 0.5,
124126
}
125127

126128

@@ -150,5 +152,12 @@ def __str__(self) -> str:
150152
np.dtype(ml_dtypes.uint4): DataType.UINT4,
151153
}
152154

155+
# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
156+
_NP_TYPE_TO_DATA_TYPE.update(
157+
{np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1}
158+
if hasattr(ml_dtypes, "float4_e2m1fn")
159+
else {}
160+
)
161+
153162
# ONNX DataType to Numpy dtype.
154163
_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}

onnxscript/ir/_enums_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def test_enums_are_the_same_as_spec(self):
3232
self.assertEqual(_enums.DataType.FLOAT8E5M2FNUZ, onnx.TensorProto.FLOAT8E5M2FNUZ)
3333
self.assertEqual(_enums.DataType.UINT4, onnx.TensorProto.UINT4)
3434
self.assertEqual(_enums.DataType.INT4, onnx.TensorProto.INT4)
35+
if hasattr(onnx.TensorProto, "FLOAT4E2M1"):
36+
self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1)
3537
self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED)
3638

3739
def test_from_numpy_takes_np_dtype_and_returns_data_type(self):

onnxscript/ir/_type_casting.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,18 @@ def unpack_int4(
8989
"""
9090
unpacked = _unpack_uint4_as_uint8(data, dims)
9191
return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4)
92+
93+
94+
def unpack_float4e2m1(
95+
data: npt.NDArray[np.uint8], dims: Sequence[int]
96+
) -> npt.NDArray[ml_dtypes.float4_e2m1fn]:
97+
"""Convert a packed float4e2m1 array to unpacked float4e2m1 array.
98+
99+
Args:
100+
data: A numpy array.
101+
dims: The dimensions are used to reshape the unpacked buffer.
102+
103+
Returns:
104+
A numpy array of float4e2m1 reshaped to dims.
105+
"""
106+
return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)

onnxscript/ir/serde.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ def numpy(self) -> np.ndarray:
323323
return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
324324
elif dtype == _enums.DataType.UINT4:
325325
return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
326+
elif dtype == _enums.DataType.FLOAT4E2M1:
327+
return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims)
326328
else:
327329
# Otherwise convert to the correct dtype and reshape
328330
# Note we cannot use view() here because the storage dtype may not be the same size as the target
@@ -369,6 +371,7 @@ def tobytes(self) -> bytes:
369371
_enums.DataType.FLOAT8E5M2FNUZ,
370372
_enums.DataType.INT4,
371373
_enums.DataType.UINT4,
374+
_enums.DataType.FLOAT4E2M1,
372375
}:
373376
# uint4, int4 and float4e2m1 values are already packed, even when stored as int32
374377
# so we don't need to pack them again
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
onnx-weekly==1.18.0.dev20241014
1+
onnx-weekly==1.18.0.dev20241021

requirements/lintrunner/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is auto updated by dependabot
22
lintrunner-adapters>=0.8.0
33
# RUFF, RUFF-FIX
4-
ruff==0.6.9
4+
ruff==0.7.0
55
# MYPY
66
mypy==1.10.1
77
types-PyYAML==6.0.12.20240808

0 commit comments

Comments
 (0)