Commit 99fc347

move weights loading related logic to ModelLoader
Signed-off-by: junq <[email protected]>
1 parent 25389c9 commit 99fc347

File tree

4 files changed: +327 additions, -293 deletions
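Only the engine-side half of the refactor is visible in the excerpt below: PyTorchModelEngine.__init__ now builds a ModelLoader and delegates the work formerly done by its _load_model method. The following sketch mirrors that new call pattern; the ModelLoader constructor and load() signatures are inferred from the call site in the diff, and the actual definition lives in the new tensorrt_llm/_torch/pyexecutor/model_loader.py, which is one of the four changed files but is not shown in this excerpt.

# Hypothetical wrapper mirroring the new wiring in PyTorchModelEngine.__init__;
# the argument names come from the call site in the diff, everything else is illustrative.
from tensorrt_llm._torch.pyexecutor.model_loader import ModelLoader


def build_model(model_path, pytorch_backend_config, mapping, spec_config,
                max_num_tokens, max_seq_len, lora_config, checkpoint_loader,
                drafting_loop_wrapper=None):
    loader = ModelLoader(
        pytorch_backend_config=pytorch_backend_config,
        mapping=mapping,
        spec_config=spec_config,
        max_num_tokens=max_num_tokens,
        max_seq_len=max_seq_len,
        lora_config=lora_config,
    )
    # load() takes over what PyTorchModelEngine._load_model used to do:
    # read the checkpoint config, validate KV-cache dtypes, instantiate the
    # model, and load (or dummy-initialize) its weights.
    return loader.load(checkpoint_dir=model_path,
                       checkpoint_loader=checkpoint_loader,
                       drafting_loop_wrapper=drafting_loop_wrapper)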

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 14 additions & 291 deletions
@@ -1,12 +1,9 @@
 import bisect
 import contextlib
-import copy
 import functools
 import gc
 import inspect
 import math
-import os
-import traceback
 import weakref
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
@@ -17,16 +14,13 @@

 import tensorrt_llm.bindings.internal.userbuffers as ub
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
-                                 str_dtype_to_torch, torch_dtype_to_str,
-                                 trace_func)
+                                 torch_dtype_to_str, trace_func)
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraModelConfig
 from tensorrt_llm.mapping import CpType, Mapping
-from tensorrt_llm.models.modeling_utils import QuantAlgo
-from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2

 from ..attention_backend.interface import (AttentionMetadata,
                                            AttentionRuntimeFeatures)
@@ -40,13 +34,10 @@
 from ..distributed.communicator import init_pp_comm
 from ..expert_statistic import ExpertStatistic
 from ..metadata import KVCacheParams
-from ..model_config import ModelConfig, MoeLoadBalancerConfig
-from ..models import AutoModelForCausalLM
 from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
-from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode,
-                                     timing)
-from ..modules.fused_moe.moe_load_balancer import (
-    MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer)
+from ..models.modeling_utils import DecoderModelForCausalLM
+from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
+                                                   MoeLoadBalancerIterContext)
 from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
                            get_spec_metadata,
                            update_spec_config_from_model_config)
@@ -55,12 +46,13 @@
 from ..utils import (get_model_extra_attrs,
                      set_per_request_piecewise_cuda_graph_flag,
                      set_torch_compiling, with_model_extra_attrs)
-from .config import LoadFormat, PyTorchConfig
+from .config import PyTorchConfig
 from .config_utils import is_mla
 from .cuda_graph_runner import CUDAGraphRunner
 from .guided_decoder import CapturableGuidedDecoder
 from .layerwise_nvtx_marker import LayerwiseNvtxMarker
 from .llm_request import get_draft_token_length
+from .model_loader import ModelLoader
 from .resource_manager import (BaseResourceManager, KVCacheManager,
                                ResourceManager, ResourceManagerType)
 from .sampler import SampleStateTensors
@@ -95,137 +87,6 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         return


-_KV_CACHE_MAP = {
-    "fp8": QuantAlgo.FP8.value,
-    "nvfp4": QuantAlgo.NVFP4.value,
-    "auto": "auto"
-}
-_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto")
-
-
-def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig,
-                                           mamba_ssm_cache_dtype: str) -> None:
-    if mamba_ssm_cache_dtype == "auto":
-        mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype
-    else:
-        mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype)
-
-    config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
-
-
-def validate_and_set_kv_cache_quant(model_config: ModelConfig,
-                                    pyt_kv_cache_dtype: str) -> QuantAlgo:
-    logger.info(
-        f'Validating KV Cache config against kv_cache_dtype="{pyt_kv_cache_dtype}"'
-    )
-    # Quantization from hf_quant_config.json
-    kv_cache_quant = model_config.quant_config.kv_cache_quant_algo
-    # PyTorch configuration quantization
-    valid_pyt_quant = bool(pyt_kv_cache_dtype in _VALID_KV_CACHE_DTYPES)
-    mapped_pyt_quant = _KV_CACHE_MAP.get(pyt_kv_cache_dtype, None)
-
-    # If we're letting the checkpoint dictate the quant with auto, simply
-    # return and do not modify the checkpoint.
-    if pyt_kv_cache_dtype == "auto":
-        logger.info(
-            f'KV cache quantization set to "{pyt_kv_cache_dtype}". Using '
-            "checkpoint KV quantization.")
-        return
-
-    # If we have an invalid quantization, simply raise an exception.
-    if not valid_pyt_quant:
-        raise ValueError(
-            "Overriding KV cache quantization with an invalid type "
-            f'"PyTorchConfig.kv_cache_dtype="{pyt_kv_cache_dtype}" '
-            f'Accepted types are "{_VALID_KV_CACHE_DTYPES}".')
-
-    # If we get to this point we have a valid quantization setting, but if
-    # we have an existing setting and it doesn't match we shouldn't proceed.
-    if kv_cache_quant is not None and mapped_pyt_quant != kv_cache_quant:
-        raise RuntimeError(
-            "Attempting to override KV cache quantization "
-            f'"{kv_cache_quant}" with PyTorchConfig.kv_cache_dtype='
-            f'"{pyt_kv_cache_dtype}". You cannot override a checkpoint with a '
-            "pre-quantized KV cache that doesn't match.")
-
-    # We have an open ended KV cache in the checkpoint
-    # and we have a specified override.
-    model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant
-
-
-def initialize_dummy_weights(
-    model: torch.nn.Module,
-    low: float = -1e-3,
-    high: float = 1e-3,
-    seed: int = 0,
-) -> None:
-    """
-    This is similar to this function in SGLang with a few changes:
-    https://github.com/sgl-project/sglang/blob/e074e76b31d4fff13e87a455dbc3acdaa92c537a/python/sglang/srt/model_loader/weight_utils.py#L577
-
-    This method is used to initialize weights with dummy values for testing
-    models without checkpoints. Unquantized (FP16/BF16/etc) values are generated
-    from a uniform distribution over the interval (low, high).
-
-    For some quantized types (FP8/NVFP4), torch has no built-in way to generate random values.
-    We simply generate values uniformly across an interval that has been empirically verified
-    to not generate NaNs/inf for these.
-    """
-
-    def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]:
-        # These values are not necessarily the largest possible min/max,
-        # they need to be small enough to avoid NaNs.
-        if dtype in (torch.float8_e4m3fn, torch.int8):
-            return (-3.0, 3.0)
-
-        elif dtype == float4_e2m1x2:
-            # These correspond to bits of 2 packed FP4 values.
-            # Because we only go up to 64, the high 4 bits will
-            # always be 0. But this is fine - we just need values
-            # that won't generate NaNs.
-            return (0, 64)
-
-        else:
-            raise NotImplementedError(f"Unknown quantized type: {dtype}.")
-
-    for param in model.state_dict().values():
-        generator = torch.Generator(device=param.data.device)
-        generator.manual_seed(seed)
-        dtype = param.data.dtype
-
-        if param.data.element_size() < 2:
-            # We need to do a cast/round since torch doesn't have uniform_
-            # support for these dtypes.
-            tmp_param = torch.empty(param.data.shape,
-                                    dtype=torch.float16,
-                                    device=param.data.device)
-
-            quant_min, quant_max = _get_random_min_max(dtype)
-            tmp_param = tmp_param.uniform_(quant_min,
-                                           quant_max,
-                                           generator=generator)
-
-            param.data.copy_(tmp_param.to(dtype))
-
-        # Note: no need to to mess with int32 params, these are probably
-        # constants and not weights.
-        elif torch.is_floating_point(param):
-            param.uniform_(low, high, generator=generator)
-
-
-def get_rank_model_storage(model):
-    total_bytes = 0
-    for _, param in model.named_parameters():
-        if param.device.type == 'cuda' and param.device.index == torch.cuda.current_device(
-        ):
-            total_bytes += param.element_size() * param.nelement()
-    for _, buf in model.named_buffers():
-        if buf.device.type == 'cuda' and buf.device.index == torch.cuda.current_device(
-        ):
-            total_bytes += buf.element_size() * buf.nelement()
-    return total_bytes
-
-
 def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
                                    max_batch_size: int, max_num_tokens: int,
                                    max_draft_len: int,
@@ -302,20 +163,17 @@ def __init__(
         )

         attn_backend = pytorch_backend_config.attn_backend
-        self.model = self._load_model(
-            model_path,
+        loader = ModelLoader(
+            pytorch_backend_config=pytorch_backend_config,
             mapping=self.mapping,
-            checkpoint_loader=checkpoint_loader,
-            attn_backend=attn_backend,
-            moe_backend=pytorch_backend_config.moe_backend,
-            moe_disable_finalize_fusion=pytorch_backend_config.
-            moe_disable_finalize_fusion,
-            load_format=pytorch_backend_config.load_format,
+            spec_config=self.spec_config,
             max_num_tokens=max_num_tokens,
-            moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
-            moe_load_balancer=pytorch_backend_config.moe_load_balancer,
+            max_seq_len=max_seq_len,
             lora_config=lora_config,
-            drafting_loop_wrapper=drafting_loop_wrapper)
+        )
+        self.model = loader.load(checkpoint_dir=model_path,
+                                 checkpoint_loader=checkpoint_loader,
+                                 drafting_loop_wrapper=drafting_loop_wrapper)
         # In case that some tests use stub models and override `_load_model`.
         if not hasattr(self.model, 'extra_attrs'):
             self.model.extra_attrs = {}
@@ -944,141 +802,6 @@ def __del__(self) -> None:
         # Release model weights.
         release_gc()

-    def _load_model(self,
-                    checkpoint_dir: str,
-                    checkpoint_loader: BaseCheckpointLoader,
-                    load_format: LoadFormat,
-                    max_num_tokens: int,
-                    moe_max_num_tokens: Optional[int] = None,
-                    moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
-                    lora_config: Optional[LoraConfig] = None,
-                    drafting_loop_wrapper: Optional[Callable[
-                        [torch.nn.Module], torch.nn.Module]] = None,
-                    **kwargs) -> DecoderModelForCausalLM:
-        config = checkpoint_loader.load_config(
-            checkpoint_dir,
-            trust_remote_code=True,
-            enable_min_latency=self.pytorch_backend_config.enable_min_latency,
-            use_cuda_graph=self.pytorch_backend_config.use_cuda_graph,
-            force_dynamic_quantization=self.pytorch_backend_config.
-            force_dynamic_quantization,
-            spec_config=self.spec_config,
-            max_num_tokens=max_num_tokens,
-            max_seq_len=self.max_seq_len,
-            moe_max_num_tokens=moe_max_num_tokens,
-            moe_load_balancer=moe_load_balancer,
-            lora_config=lora_config,
-            allreduce_strategy=self.pytorch_backend_config.allreduce_strategy,
-            mm_encoder_only=self.pytorch_backend_config.mm_encoder_only,
-            **kwargs)
-
-        validate_and_set_kv_cache_quant(
-            config, self.pytorch_backend_config.kv_cache_dtype)
-        validate_and_set_mamba_ssm_cache_dtype(
-            config, self.pytorch_backend_config.mamba_ssm_cache_dtype)
-
-        num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0"))
-        if num_layers > 0:
-            config.pretrained_config.num_hidden_layers = num_layers
-            for sub_config in ["text_config", "vision_config"]:
-                if hasattr(config.pretrained_config, sub_config):
-                    getattr(config.pretrained_config,
-                            sub_config).num_hidden_layers = num_layers
-
-        with timing("Model init total"), maybe_create_moe_load_balancer(
-                config, self.mapping) as moe_load_balancer:
-
-            try:
-                # config will be modified in-place for some models, like Qwen2
-                config_copy = copy.deepcopy(config)
-                with MetaInitMode():
-                    model = AutoModelForCausalLM.from_config(config_copy)
-
-                memo = dict()
-
-                def init_meta_tensor(t: torch.Tensor):
-                    if t.device != torch.device('meta'):
-                        return t
-                    if t not in memo:
-                        memo[t] = torch.empty_like(t, device='cuda')
-                    return memo[t]
-
-                model._apply(init_meta_tensor)
-                config = config_copy
-
-            except Exception:
-                logger.info(
-                    f"Fallback to regular model init: {traceback.format_exc(limit=1)}\n"
-                )
-                model = AutoModelForCausalLM.from_config(config)
-
-            model.to("cuda")
-            rank_model_storage = get_rank_model_storage(model)
-            logger.info(
-                f"Use {rank_model_storage / (1024**3):.2f} GB for model weights."
-            )
-            if load_format == LoadFormat.AUTO:
-                if hasattr(model, 'llm_checkpoint_dir'):
-                    weights = checkpoint_loader.load_weights(
-                        model.llm_checkpoint_dir)
-                else:
-                    weights = checkpoint_loader.load_weights(checkpoint_dir)
-
-                weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
-                    model, config)
-                self._call_load_weights(model.load_weights, weights,
-                                        weight_mapper)
-
-                if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
-                ):
-                    weights = checkpoint_loader.load_weights(
-                        self.spec_config.speculative_model_dir)
-                    self._call_load_weights(model.load_draft_weights, weights,
-                                            weight_mapper)
-
-            elif load_format == LoadFormat.DUMMY:
-                initialize_dummy_weights(model)
-                if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
-                ):
-                    model.draft_model.load_weights_from_target_model(model)
-
-            elif load_format == LoadFormat.VISION_ONLY:
-                # Vision weights are already loaded within the model.
-                logger.info(
-                    "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights."
-                )
-
-            else:
-                raise NotImplementedError(
-                    f"No load support for load format: {load_format}")
-
-            if isinstance(moe_load_balancer, MoeLoadBalancer):
-                setattr(self, "moe_load_balancer", moe_load_balancer)
-                moe_load_balancer.register_weight_slots_after_to_cuda()
-                logger.info("moe_load_balancer finalizing model...")
-                moe_load_balancer.finalize_model()
-                logger.info("moe_load_balancer finalize model done")
-
-            torch.cuda.current_stream().synchronize()
-
-        if drafting_loop_wrapper is not None:
-            model = drafting_loop_wrapper(model)
-            self.model_is_wrapped = True
-        else:
-            self.model_is_wrapped = False
-
-        return model
-
-    def _call_load_weights(self, load_method, weights, weight_mapper):
-        # TODO smor- this is a temporary solution to load weights.
-        # Once checkpoint format is unified, this method will be removed.
-        from inspect import getfullargspec
-        args = getfullargspec(load_method).args
-        if "weight_mapper" in args:
-            load_method(weights, weight_mapper=weight_mapper)
-        else:
-            load_method(weights)
-
     def _init_max_seq_len(self):
         # For mm_encoder_only mode, infer_max_seq_len() is for LLM decoder models
         if hasattr(self.model, 'infer_max_seq_len'):
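One detail worth noting from the removed code: until the checkpoint format is unified, weight loading still has to dispatch on whether a model's load_weights() accepts a weight_mapper keyword (this was the job of the removed _call_load_weights helper). A minimal standalone sketch of that dispatch pattern, using only the standard inspect module; the function name here is illustrative and is not the new module's actual API:

import inspect
from typing import Any, Callable, Dict, Optional


def call_load_weights(load_method: Callable[..., None],
                      weights: Dict[str, Any],
                      weight_mapper: Optional[Any] = None) -> None:
    # Pass weight_mapper only when the target load_weights() declares it;
    # older model implementations accept just the weights dict.
    args = inspect.getfullargspec(load_method).args
    if "weight_mapper" in args:
        load_method(weights, weight_mapper=weight_mapper)
    else:
        load_method(weights)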
