@@ -55,6 +55,7 @@ def test_init_requires_type_when_value_is_not_np_array(self):
5555 ("int4" , np .int8 , ir .DataType .INT4 ),
5656 ("int4_uint8" , np .uint8 , ir .DataType .INT4 ),
5757 ("uint4" , np .uint8 , ir .DataType .UINT4 ),
58+ ("float4e2m1" , np .uint8 , ir .DataType .FLOAT4E2M1 ),
5859 ]
5960 )
6061 def test_init_with_non_native_numpy_dtype (self , _ : str , np_dtype , dtype : ir .DataType ):
@@ -131,34 +132,48 @@ def test_tobytes(self):
131132 tensor = _core .Tensor (torch_tensor , dtype = ir .DataType .FLOAT )
132133 self .assertEqual (tensor .tobytes (), array .tobytes ())
133134
134- def test_tobtyes_returns_packed_data_for_int4 (self ):
135+ def test_tobytes_returns_packed_data_for_int4 (self ):
135136 array = np .array ([- 8 , - 1 , 0 , 1 , 2 , 7 , 1 ], dtype = np .int8 )
136137 # Test odd sized array
137138 assert len (array ) % 2 == 1
138139 tensor = _core .Tensor (array , dtype = ir .DataType .INT4 )
139140 self .assertEqual (tensor .tobytes (), b"\xf8 \x10 r\x01 " )
140141
141- def test_tobtyes_returns_packed_data_for_int4_ml_dtypes (self ):
142+ def test_tobytes_returns_packed_data_for_int4_ml_dtypes (self ):
142143 array = np .array ([- 8 , - 1 , 0 , 1 , 2 , 7 , 1 ], dtype = ml_dtypes .int4 )
143144 # Test odd sized array
144145 assert len (array ) % 2 == 1
145146 tensor = _core .Tensor (array , dtype = ir .DataType .INT4 )
146147 self .assertEqual (tensor .tobytes (), b"\xf8 \x10 r\x01 " )
147148
148- def test_tobtyes_returns_packed_data_for_uint4 (self ):
149+ def test_tobytes_returns_packed_data_for_uint4 (self ):
149150 array = np .array ([0 , 1 , 2 , 7 , 15 ], dtype = np .uint8 )
150151 # Test odd sized array
151152 assert len (array ) % 2 == 1
152153 tensor = _core .Tensor (array , dtype = ir .DataType .UINT4 )
153154 self .assertEqual (tensor .tobytes (), b"\x10 r\x0f " )
154155
155- def test_tobtyes_returns_packed_data_for_uint4_ml_dtypes (self ):
156+ def test_tobytes_returns_packed_data_for_uint4_ml_dtypes (self ):
156157 array = np .array ([0 , 1 , 2 , 7 , 15 ], dtype = ml_dtypes .uint4 )
157158 # Test odd sized array
158159 assert len (array ) % 2 == 1
159160 tensor = _core .Tensor (array , dtype = ir .DataType .UINT4 )
160161 self .assertEqual (tensor .tobytes (), b"\x10 r\x0f " )
161162
163+ def test_tobytes_returns_packed_data_for_float4e2m1 (self ):
164+ array = np .array ([0 , 1 , 2 , 7 , 15 ], dtype = np .uint8 )
165+ # Test odd sized array
166+ assert len (array ) % 2 == 1
167+ tensor = _core .Tensor (array , dtype = ir .DataType .FLOAT4E2M1 )
168+ self .assertEqual (tensor .tobytes (), b"\x10 r\x0f " )
169+
170+ def test_tobytes_returns_packed_data_for_float4e2m1_ml_dtypes (self ):
171+ array = np .array ([0 , 1 , 2 , 7 , 15 ], dtype = np .uint8 )
172+ # Test odd sized array
173+ assert len (array ) % 2 == 1
174+ tensor = _core .Tensor (array , dtype = ir .DataType .FLOAT4E2M1 )
175+ self .assertEqual (tensor .tobytes (), b"\x10 r\x0f " )
176+
162177 def test_metadata (self ):
163178 array = np .random .rand (1 , 2 ).astype (np .float32 )
164179 tensor = _core .Tensor (array )
@@ -444,6 +459,19 @@ def test_external_tensor_complex(self, _: str, np_dtype: np.dtype):
444459 # about permission errors
445460 del tensor
446461
462+ def test_external_tensor_float4e2m1 (self ):
463+ expected_array = np .array ([0 , 1 , 2 , 7 , 15 ]).view (ml_dtypes .float4_e2m1fn )
464+ tensor_proto = ir .serde .serialize_tensor (
465+ ir .Tensor (expected_array , dtype = ir .DataType .FLOAT4E2M1 )
466+ )
467+ with tempfile .TemporaryDirectory () as temp_dir :
468+ _to_external_tensor (tensor_proto , temp_dir , "tensor.bin" )
469+ tensor = ir .serde .deserialize_tensor (tensor_proto , temp_dir )
470+ np .testing .assert_array_equal (tensor .numpy (), expected_array )
471+ # Close the mmap file by deleting the reference to tensor so Windows doesn't complain
472+ # about permission errors
473+ del tensor
474+
447475 def test_external_tensor_empty_tensor (self ):
448476 expected_array = np .array ([], dtype = np .float32 )
449477 tensor_proto = ir .serde .serialize_tensor (ir .Tensor (expected_array ))
0 commit comments