
import gc
import logging
-from typing import Optional, Tuple
+import os
+import warnings
+from collections import namedtuple
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Optional, Tuple

import numpy as np
-from onnx import ModelProto, external_data_helper, numpy_helper
+from onnx import ModelProto, TensorProto, external_data_helper, numpy_helper

from QEfficient.utils.constants import ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL

logger = logging.getLogger(__name__)


-class OnnxTransform:
+class BaseOnnxTransform:
    """
-    OnnxTransform is the base class for graph modifications on exported onnx.
+    BaseOnnxTransform is the base class for graph modifications on exported ONNX.
    """

    _external_data_loaded_cache = {}  # Dict[int, bool]

    def __init__(self):
-        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")
+        raise TypeError("Transform classes are not to be instantiated. Use the `apply` method directly.")

    @classmethod
    def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
@@ -47,15 +51,11 @@ def _check_external_data_loaded(cls, model: ModelProto) -> bool:
        :param model: The ONNX model to check
        :returns: True if external data is already loaded, False otherwise
        """
-        # Use object ID as key instead of the object itself
        model_id = id(model)
-        # Return cached result if available
        if model_id in cls._external_data_loaded_cache:
            return cls._external_data_loaded_cache[model_id]

-        # Load the model if not already loaded
        for tensor in external_data_helper._get_all_tensors(model):
-            # Check if tensor has external data but no raw data loaded
            if len(tensor.external_data) > 0 and not tensor.HasField("raw_data"):
                cls._external_data_loaded_cache[model_id] = False
                return False
@@ -77,6 +77,13 @@ def _load_external_data(cls, model: ModelProto, onnx_base_dir: Optional[str] = N
        else:
            logger.info("External data already loaded (or cached). Skipping bulk load.")

+    @classmethod
+    def _cleanup_memory(cls):
+        """
+        Force garbage collection to free up memory after tensor processing.
+        """
+        gc.collect()
+
    @classmethod
    def _cleanup_external_data_and_cache(cls, model: ModelProto):
        """
@@ -94,108 +101,99 @@ def _cleanup_external_data_and_cache(cls, model: ModelProto):

        logger.info("External data and cache cleaned up.")

-    @classmethod
-    def _cleanup_memory(cls):
-        """
-        Force garbage collection to free up memory after tensor processing.
-        """
-        gc.collect()
-
-
-class FP16ClipTransform(OnnxTransform):
-    """
-    Clips the tensor values to be in FP16 range, but preserves -inf values.
-    """

+class OnnxTransform(BaseOnnxTransform):
    @classmethod
-    def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
-        """
-        :param onnx_base_dir: Base directory to load tensors
-        """
+    def apply(
+        cls,
+        model: ModelProto,
+        *,
+        transforms: List[str],
+        model_name: str = "",
+        onnx_base_dir: Optional[str] = None,
+        file_chunk_size: int = 10 * 2**30,  # 10 GiB
+        size_threshold: int = 1024,
+        **kwargs,
+    ) -> Tuple[ModelProto, bool]:
+        """
+        :param transforms: Names of the transforms to apply, e.g. ["FP16ClipTransform", "SplitTensorsTransform"].
+        :param model_name: Used for naming external files. i.e. {model_name}_0.onnx.data
+        :param onnx_base_dir: Base directory to load tensors (if not already loaded).
+        :param file_chunk_size: Chunk size to split external files into.
+        :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally.
+        """
+        if len(transforms) == 0:
+            warnings.warn("Transforms list is empty. Skipping transformation.")
+            return model, False
+
        try:
-            # --- FIX: Ensure external data is loaded efficiently BEFORE processing ---
            cls._load_external_data(model, onnx_base_dir)
+            tensors = external_data_helper._get_all_tensors(model)

-            finfo = np.finfo(np.float16)
-            fp16_max = finfo.max
-            fp16_min = finfo.min
-            transformed = False
+            TensorInfo = namedtuple("TensorInfo", ["tensor", "tsize"])
+            tensor_infos = [
+                TensorInfo(tensor, len(tensor.raw_data) if tensor.HasField("raw_data") else 0) for tensor in tensors
+            ]

-            processed_count = 0
-            for tensor in external_data_helper._get_all_tensors(model):
-                nptensor = numpy_helper.to_array(tensor)  # Removed onnx_base_dir as data is already loaded
-                if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
-                    neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
-                    clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
+            fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
+            file_num_tracker = {"num": 0, "size": 0}

-                    # Restore -inf values
-                    if neg_inf_mask.any():
-                        clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
+            # Track which transforms were requested and which were actually applied
+            requested_transforms = set(transforms)
+            applied_transforms = {name: False for name in requested_transforms}

-                    new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
-                    tensor.CopyFrom(new_tensor)
-                    transformed = True
+            def process_tensor(index_info: Tuple[int, TensorInfo]) -> List[str]:
+                idx, info = index_info
+                tensor, tsize = info

-                    del neg_inf_mask, clipped_tensor, new_tensor
+                local_applied = []

-                del nptensor
-                processed_count += 1
+                if "FP16ClipTransform" in requested_transforms:
+                    if cls._clip_tensor(tensor, fp16_min, fp16_max):
+                        local_applied.append("FP16ClipTransform")

-                if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
-                    cls._cleanup_memory()
+                if "SplitTensorsTransform" in requested_transforms and tsize > size_threshold:
+                    if file_num_tracker["size"] + tsize > file_chunk_size:
+                        file_num_tracker["num"] += 1
+                        file_num_tracker["size"] = tsize
+                    else:
+                        file_num_tracker["size"] += tsize

-            return model, transformed
-        finally:
-            # Ensure cleanup happens even if an exception occurs
-            cls._cleanup_memory()
+                    cls._split_tensor(tensor, model_name, file_num_tracker["num"])
+                    local_applied.append("SplitTensorsTransform")

+                if (idx + 1) % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
+                    cls._cleanup_memory()

-class SplitTensorsTransform(OnnxTransform):
-    """
-    Split external tensors file
-    """
+                return local_applied

-    @classmethod
-    def apply(
-        cls,
-        model: ModelProto,
-        *,
-        model_name: str,
-        onnx_base_dir: Optional[str] = None,
-        file_chunk_size: int = 10 * 2**30,  # 10 GiB
-        size_threshold: int = 1024,
-        **kwargs,
-    ) -> Tuple[ModelProto, bool]:
-        """
-        :param model_name: Used for naming external files. i.e. {model_name}_0.onnx.data
-        :param onnx_base_dir: Base directory to load tensors (if not already loaded).
-        :param file_chunk_size: Chunk size to split external files into.
-        :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally.
-        """
-        try:
-            file_num = 0
-            current_file_size = 0
-            transformed = False
+            with ThreadPoolExecutor(max_workers=os.cpu_count() * 4) as executor:
+                results = list(executor.map(process_tensor, enumerate(tensor_infos)))

-            # --- Adjustment: The initial check and load will now use the new bulk loader ---
-            # This will either use the cache (if FP16ClipTransform loaded it) or perform the bulk load itself.
-            cls._load_external_data(model, onnx_base_dir)
+            for result in results:
+                for transform_name in result:
+                    applied_transforms[transform_name] = True

-            processed_count = 0
-            for tensor in external_data_helper._get_all_tensors(model):
-                if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold):
-                    transformed = True
-                    current_file_size += tsize
-                    if current_file_size > file_chunk_size:
-                        file_num += 1
-                        current_file_size = tsize
-                    external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
-
-                processed_count += 1
-                if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
-                    cls._cleanup_memory()
+            for name in requested_transforms:
+                if applied_transforms[name]:
+                    logger.info(f"Transform '{name}' was applied.")
+                else:
+                    logger.warning(f"Transform '{name}' was requested but not applied.")
+
+            return model, any(applied_transforms.values())

-            return model, transformed
        finally:
-            # Ensure cleanup happens even if an exception occurs
            cls._cleanup_memory()
+
+    @staticmethod
+    def _clip_tensor(tensor, fp16_min, fp16_max) -> bool:
+        """Clip FP32 tensor values into FP16 range, preserving -inf values; returns True if the tensor was modified."""
+        if tensor.data_type != TensorProto.FLOAT:
+            return False
+
+        nptensor = numpy_helper.to_array(tensor)
+        if np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min):
+            neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
+            clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
+            if neg_inf_mask.any():
+                clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
+            new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
+            tensor.CopyFrom(new_tensor)
+            return True
+        return False
+
+    @staticmethod
+    def _split_tensor(tensor, model_name: str, file_num: int) -> None:
+        """Redirect the tensor's external data to the chunk file {model_name}_{file_num}.onnx.data."""
+        external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
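
For context, a minimal usage sketch of the consolidated entry point follows; the model path, `model_name`, and import path are hypothetical, not part of this change. Both passes now share a single walk over the graph's tensors, and the returned flag reports whether any requested transform actually fired.

import onnx

from QEfficient.base.onnx_transforms import OnnxTransform  # hypothetical import path

# Load only the graph structure; OnnxTransform.apply bulk-loads the external
# tensor data itself via _load_external_data.
model = onnx.load("model.onnx", load_external_data=False)

model, transformed = OnnxTransform.apply(
    model,
    transforms=["FP16ClipTransform", "SplitTensorsTransform"],
    model_name="model",  # external chunks are named model_0.onnx.data, model_1.onnx.data, ...
    onnx_base_dir=".",   # directory holding the original external-data files
)

Note that chunk numbering comes from `file_num_tracker`, which the `process_tensor` workers share, so the assignment of tensors to chunk files follows the order in which the thread pool visits them.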