Commit 8f238ca

Merged PR 546 with this

1 parent 27a8b8e commit 8f238ca

6 files changed: +133 −131 lines changed

QEfficient/base/modeling_qeff.py
Lines changed: 7 additions & 5 deletions

@@ -47,11 +47,12 @@ class QEFFBaseModel(ABC):
     """
 
     _pytorch_transforms: List[PytorchTransform]
-    _onnx_transforms: List[OnnxTransform]
+    _onnx_transforms = ["FP16ClipTransform", "SplitTensorsTransform"]
 
     @classmethod
     def _transform_names(cls) -> List[str]:
-        return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
+        pytorch_names = [x.__name__ for x in cls._pytorch_transforms]
+        return pytorch_names + cls._onnx_transforms
 
     def __init__(self, model: torch.nn.Module, **kwargs) -> None:
         super().__init__()
@@ -321,9 +322,10 @@ def _export(
         }
         if onnx_transform_kwargs is not None:
             transform_kwargs.update(onnx_transform_kwargs)
-
-        for transform in self._onnx_transforms:
-            model, transformed = transform.apply(model, **transform_kwargs)
+        # import pdb; pdb.set_trace()
+        transform_kwargs["transforms"] = self._onnx_transforms
+        # for transform in self._onnx_transforms:
+        model, transformed = OnnxTransform.apply(model, **transform_kwargs)
 
         model.metadata_props.append(
             onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names()))
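With this change, _onnx_transforms holds transform names as plain strings rather than transform classes, so _transform_names can no longer read __name__ on every entry. A minimal sketch of the resulting behavior (the DummyTransform class below is hypothetical, for illustration only):

    from typing import List

    class DummyTransform:  # hypothetical stand-in for a PytorchTransform subclass
        pass

    class Example:
        _pytorch_transforms: List[type] = [DummyTransform]
        _onnx_transforms = ["FP16ClipTransform", "SplitTensorsTransform"]

        @classmethod
        def _transform_names(cls) -> List[str]:
            # Classes contribute their __name__; ONNX entries are already strings.
            pytorch_names = [x.__name__ for x in cls._pytorch_transforms]
            return pytorch_names + cls._onnx_transforms

    print(Example._transform_names())
    # ['DummyTransform', 'FP16ClipTransform', 'SplitTensorsTransform']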

QEfficient/base/onnx_transforms.py
Lines changed: 91 additions & 93 deletions

@@ -7,25 +7,29 @@
 
 import gc
 import logging
-from typing import Optional, Tuple
+import os
+import warnings
+from collections import namedtuple
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Optional, Tuple
 
 import numpy as np
-from onnx import ModelProto, external_data_helper, numpy_helper
+from onnx import ModelProto, TensorProto, external_data_helper, numpy_helper
 
 from QEfficient.utils.constants import ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL
 
 logger = logging.getLogger(__name__)
 
 
-class OnnxTransform:
+class BaseOnnxTransform:
     """
     OnnxTransform is the base class for graph modifications on exported onnx.
     """
 
     _external_data_loaded_cache = {}  # Dict[int, bool]
 
     def __init__(self):
-        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")
+        raise TypeError("Transform classes are not to be instantiated. Use the `apply` method directly.")
 
     @classmethod
     def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
@@ -47,15 +51,11 @@ def _check_external_data_loaded(cls, model: ModelProto) -> bool:
         :param model: The ONNX model to check
         :returns: True if external data is already loaded, False otherwise
         """
-        # Use object ID as key instead of the object itself
         model_id = id(model)
-        # Return cached result if available
         if model_id in cls._external_data_loaded_cache:
            return cls._external_data_loaded_cache[model_id]
 
-        # Load the model if not already loaded
         for tensor in external_data_helper._get_all_tensors(model):
-            # Check if tensor has external data but no raw data loaded
            if len(tensor.external_data) > 0 and not tensor.HasField("raw_data"):
                cls._external_data_loaded_cache[model_id] = False
                return False
@@ -77,6 +77,13 @@ def _load_external_data(cls, model: ModelProto, onnx_base_dir: Optional[str] = None):
         else:
             logger.info("External data already loaded (or cached). Skipping bulk load.")
 
+    @classmethod
+    def _cleanup_memory(cls):
+        """
+        Force garbage collection to free up memory after tensor processing.
+        """
+        gc.collect()
+
     @classmethod
     def _cleanup_external_data_and_cache(cls, model: ModelProto):
         """
@@ -94,108 +101,99 @@ def _cleanup_external_data_and_cache(cls, model: ModelProto):
 
         logger.info("External data and cache cleaned up.")
 
-    @classmethod
-    def _cleanup_memory(cls):
-        """
-        Force garbage collection to free up memory after tensor processing.
-        """
-        gc.collect()
-
-
-class FP16ClipTransform(OnnxTransform):
-    """
-    Clips the tensor values to be in FP16 range, but preserves -inf values.
-    """
 
+class OnnxTransform(BaseOnnxTransform):
     @classmethod
-    def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
-        """
-        :param onnx_base_dir: Base directory to load tensors
-        """
+    def apply(
+        cls,
+        model: ModelProto,
+        *,
+        transforms: List[str],
+        model_name: str = "",
+        onnx_base_dir: Optional[str] = None,
+        file_chunk_size: int = 10 * 2**30,
+        size_threshold: int = 1024,
+        **kwargs,
+    ) -> Tuple[ModelProto, bool]:
+        if len(transforms) == 0:
+            warnings.warn("Transform is empty. Skipping transformation.")
+            return model, False
+
         try:
-            # --- FIX: Ensure external data is loaded efficiently BEFORE processing ---
             cls._load_external_data(model, onnx_base_dir)
+            tensors = external_data_helper._get_all_tensors(model)
 
-            finfo = np.finfo(np.float16)
-            fp16_max = finfo.max
-            fp16_min = finfo.min
-            transformed = False
+            TensorInfo = namedtuple("TensorInfo", ["tensor", "tsize"])
+            tensor_infos = [
+                TensorInfo(tensor, len(tensor.raw_data) if tensor.HasField("raw_data") else 0) for tensor in tensors
+            ]
 
-            processed_count = 0
-            for tensor in external_data_helper._get_all_tensors(model):
-                nptensor = numpy_helper.to_array(tensor)  # Removed onnx_base_dir as data is already loaded
-                if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
-                    neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
-                    clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
+            fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
+            file_num_tracker = {"num": 0, "size": 0}
 
-                    # Restore -inf values
-                    if neg_inf_mask.any():
-                        clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
+            # Track which transforms were requested and which were actually applied
+            requested_transforms = set(transforms)
+            applied_transforms = {name: False for name in requested_transforms}
 
-                    new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
-                    tensor.CopyFrom(new_tensor)
-                    transformed = True
+            def process_tensor(index_info: Tuple[int, TensorInfo]) -> List[str]:
+                idx, info = index_info
+                tensor, tsize = info
 
-                    del neg_inf_mask, clipped_tensor, new_tensor
+                local_applied = []
 
-                del nptensor
-                processed_count += 1
+                if "FP16ClipTransform" in requested_transforms:
+                    if cls._clip_tensor(tensor, fp16_min, fp16_max):
+                        local_applied.append("FP16ClipTransform")
 
-                if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
-                    cls._cleanup_memory()
+                if "SplitTensorsTransform" in requested_transforms and tsize > size_threshold:
+                    if file_num_tracker["size"] + tsize > file_chunk_size:
+                        file_num_tracker["num"] += 1
+                        file_num_tracker["size"] = tsize
+                    else:
+                        file_num_tracker["size"] += tsize
 
-            return model, transformed
-        finally:
-            # Ensure cleanup happens even if an exception occurs
-            cls._cleanup_memory()
+                    cls._split_tensor(tensor, model_name, file_num_tracker["num"])
+                    local_applied.append("SplitTensorsTransform")
 
+                if (idx + 1) % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
+                    cls._cleanup_memory()
 
-class SplitTensorsTransform(OnnxTransform):
-    """
-    Split external tensors file
-    """
+                return local_applied
 
-    @classmethod
-    def apply(
-        cls,
-        model: ModelProto,
-        *,
-        model_name: str,
-        onnx_base_dir: Optional[str] = None,
-        file_chunk_size: int = 10 * 2**30,  # 10 GiB
-        size_threshold: int = 1024,
-        **kwargs,
-    ) -> Tuple[ModelProto, bool]:
-        """
-        :param model_name: Used for naming external files. i.e. {model_name}_0.onnx.data
-        :param onnx_base_dir: Base directory to load tensors (if not already loaded).
-        :param file_chunk_size: Chunk size to split external files into.
-        :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally.
-        """
-        try:
-            file_num = 0
-            current_file_size = 0
-            transformed = False
+            with ThreadPoolExecutor(max_workers=os.cpu_count() * 4) as executor:
+                results = list(executor.map(process_tensor, enumerate(tensor_infos)))
 
-            # --- Adjustment: The initial check and load will now use the new bulk loader ---
-            # This will either use the cache (if FP16ClipTransform loaded it) or perform the bulk load itself.
-            cls._load_external_data(model, onnx_base_dir)
+            for result in results:
+                for transform_name in result:
+                    applied_transforms[transform_name] = True
 
-            processed_count = 0
-            for tensor in external_data_helper._get_all_tensors(model):
-                if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold):
-                    transformed = True
-                    current_file_size += tsize
-                    if current_file_size > file_chunk_size:
-                        file_num += 1
-                        current_file_size = tsize
-                    external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
-
-                processed_count += 1
-                if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
-                    cls._cleanup_memory()
+            for name in requested_transforms:
+                if applied_transforms[name]:
+                    logger.info(f"Transform '{name}' was applied.")
+                else:
+                    logger.warning(f"Transform '{name}' was requested but not applied.")
+
+            return model, any(applied_transforms.values())
 
-            return model, transformed
         finally:
-            # Ensure cleanup happens even if an exception occurs
             cls._cleanup_memory()
+
+    @staticmethod
+    def _clip_tensor(tensor, fp16_min, fp16_max) -> bool:
+        if tensor.data_type != TensorProto.FLOAT:
+            return False
+
+        nptensor = numpy_helper.to_array(tensor)
+        if np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min):
+            neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
+            clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
+            if neg_inf_mask.any():
+                clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
+            new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
+            tensor.CopyFrom(new_tensor)
+            return True
+        return False
+
+    @staticmethod
+    def _split_tensor(tensor, model_name: str, file_num: int) -> None:
+        external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
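The two former transform classes are now routed through a single OnnxTransform.apply call that takes transform names as strings and walks the tensor list once, mapped over a thread pool of os.cpu_count() * 4 workers. A minimal usage sketch, assuming this revision of QEfficient is importable (the toy model below is constructed purely for illustration):

    import numpy as np
    import onnx
    from onnx import helper, numpy_helper

    from QEfficient.base.onnx_transforms import OnnxTransform

    # One float32 initializer that overflows the FP16 range and contains a -inf.
    weight = numpy_helper.from_array(
        np.array([1e6, -np.inf, 3.0], dtype=np.float32), name="w"
    )
    graph = helper.make_graph(
        nodes=[helper.make_node("Identity", ["w"], ["y"])],
        name="toy",
        inputs=[],
        outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [3])],
        initializer=[weight],
    )
    model = helper.make_model(graph)

    model, transformed = OnnxTransform.apply(
        model,
        transforms=["FP16ClipTransform", "SplitTensorsTransform"],
        model_name="toy",  # names any external-data chunks {model_name}_{n}.onnx.data
    )
    print(transformed)  # True: the clip applied (the tiny tensor stays under size_threshold, so no split)
    print(numpy_helper.to_array(model.graph.initializer[0]))
    # [65504.   -inf     3.]  -- clamped to FP16 max; -inf preserved

Because each transform now reports back whether it actually touched a tensor, a warning is logged for any transform that was requested but ended up not applied (SplitTensorsTransform in this toy case).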

QEfficient/exporter/export_utils.py
Lines changed: 8 additions & 2 deletions

@@ -17,7 +17,7 @@
 import torch
 from onnx import external_data_helper
 
-from QEfficient.base.onnx_transforms import FP16ClipTransform
+from QEfficient.base.onnx_transforms import OnnxTransform
 
 
 def export_onnx(
@@ -218,7 +218,13 @@ def fix_onnx_fp16(
     :str: Updated base name of exported ONNX model.
     """
     model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
-    model, fp16_fix = FP16ClipTransform.apply(model, onnx_base_dir=gen_models_path)
+    if "model" in locals():
+        OnnxTransform._cleanup_external_data_and_cache(gen_models_path)
+        OnnxTransform._cleanup_memory()
+
+    model, fp16_fix = OnnxTransform.apply(
+        model, model_name="", onnx_base_dir=gen_models_path, transforms=["FP16ClipTransform"]
+    )
 
     if fp16_fix:
         # Save FP16 model
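For reference, the clipping rule fix_onnx_fp16 now relies on is the same arithmetic as _clip_tensor above; a standalone NumPy sketch of it:

    import numpy as np

    fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
    x = np.array([1e6, -np.inf, 3.0], dtype=np.float32)

    neg_inf_mask = np.isinf(x) & (x < 0)      # remember where -inf lives
    clipped = np.clip(x, fp16_min, fp16_max)  # clamp into the FP16-representable range
    clipped = np.where(neg_inf_mask, np.float32("-inf"), clipped)  # restore -inf

    print(clipped)  # [65504.   -inf     3.]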

QEfficient/peft/auto.py
Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@
 from transformers.generation.streamers import BaseStreamer
 
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
+from QEfficient.base.onnx_transforms import BaseOnnxTransform, OnnxTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
@@ -58,7 +58,7 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel):
     """
 
     _pytorch_transforms: List[PytorchTransform] = [CustomOpsTransform, KVCacheTransform, PeftModelInputsTransform]
-    _onnx_transforms: List[OnnxTransform] = [FP16ClipTransform, AdapterWeightsToInputsTransform, SplitTensorsTransform]
+    _onnx_transforms: List[BaseOnnxTransform] = [OnnxTransform, AdapterWeightsToInputsTransform]
     _hf_auto_class = AutoPeftModelForCausalLM
 
     def __init__(self, model: nn.Module):
