1 | 1 | import bisect |
2 | 2 | import contextlib |
3 | | -import copy |
4 | 3 | import functools |
5 | 4 | import gc |
6 | 5 | import inspect |
7 | 6 | import math |
8 | | -import os |
9 | | -import traceback |
10 | 7 | import weakref |
11 | 8 | from abc import ABC, abstractmethod |
12 | 9 | from contextlib import contextmanager |
17 | 14 | |
18 | 15 | import tensorrt_llm.bindings.internal.userbuffers as ub |
19 | 16 | from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc, |
20 | | - str_dtype_to_torch, torch_dtype_to_str, |
21 | | - trace_func) |
| 17 | + torch_dtype_to_str, trace_func) |
22 | 18 | from tensorrt_llm.inputs.multimodal import (MultimodalParams, |
23 | 19 | MultimodalRuntimeData) |
24 | 20 | from tensorrt_llm.logger import logger |
25 | 21 | from tensorrt_llm.lora_helper import LoraConfig |
26 | 22 | from tensorrt_llm.lora_manager import LoraModelConfig |
27 | 23 | from tensorrt_llm.mapping import CpType, Mapping |
28 | | -from tensorrt_llm.models.modeling_utils import QuantAlgo |
29 | | -from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2 |
30 | 24 | |
31 | 25 | from ..attention_backend.interface import (AttentionMetadata, |
32 | 26 | AttentionRuntimeFeatures) |
40 | 34 | from ..distributed.communicator import init_pp_comm |
41 | 35 | from ..expert_statistic import ExpertStatistic |
42 | 36 | from ..metadata import KVCacheParams |
43 | | -from ..model_config import ModelConfig, MoeLoadBalancerConfig |
44 | | -from ..models import AutoModelForCausalLM |
45 | 37 | from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader |
46 | | -from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, |
47 | | - timing) |
48 | | -from ..modules.fused_moe.moe_load_balancer import ( |
49 | | - MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer) |
| 38 | +from ..models.modeling_utils import DecoderModelForCausalLM |
| 39 | +from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer, |
| 40 | + MoeLoadBalancerIterContext) |
50 | 41 | from ..speculative import (SpecMetadata, get_num_extra_kv_tokens, |
51 | 42 | get_spec_metadata, |
52 | 43 | update_spec_config_from_model_config) |
55 | 46 | from ..utils import (get_model_extra_attrs, |
56 | 47 | set_per_request_piecewise_cuda_graph_flag, |
57 | 48 | set_torch_compiling, with_model_extra_attrs) |
58 | | -from .config import LoadFormat, PyTorchConfig |
| 49 | +from .config import PyTorchConfig |
59 | 50 | from .config_utils import is_mla |
60 | 51 | from .cuda_graph_runner import CUDAGraphRunner |
61 | 52 | from .guided_decoder import CapturableGuidedDecoder |
62 | 53 | from .layerwise_nvtx_marker import LayerwiseNvtxMarker |
63 | 54 | from .llm_request import get_draft_token_length |
| 55 | +from .model_loader import ModelLoader |
64 | 56 | from .resource_manager import (BaseResourceManager, KVCacheManager, |
65 | 57 | ResourceManager, ResourceManagerType) |
66 | 58 | from .sampler import SampleStateTensors |
@@ -95,137 +87,6 @@ def warmup(self, resource_manager: ResourceManager) -> None: |
95 | 87 | return |
96 | 88 | |
97 | 89 | |
98 | | -_KV_CACHE_MAP = { |
99 | | - "fp8": QuantAlgo.FP8.value, |
100 | | - "nvfp4": QuantAlgo.NVFP4.value, |
101 | | - "auto": "auto" |
102 | | -} |
103 | | -_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto") |
104 | | - |
105 | | - |
106 | | -def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig, |
107 | | - mamba_ssm_cache_dtype: str) -> None: |
108 | | - if mamba_ssm_cache_dtype == "auto": |
109 | | - mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype |
110 | | - else: |
111 | | - mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype) |
112 | | - |
113 | | - config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype |
114 | | - |
115 | | - |
116 | | -def validate_and_set_kv_cache_quant(model_config: ModelConfig, |
117 | | - pyt_kv_cache_dtype: str) -> QuantAlgo: |
118 | | - logger.info( |
119 | | - f'Validating KV Cache config against kv_cache_dtype="{pyt_kv_cache_dtype}"' |
120 | | - ) |
121 | | - # Quantization from hf_quant_config.json |
122 | | - kv_cache_quant = model_config.quant_config.kv_cache_quant_algo |
123 | | - # PyTorch configuration quantization |
124 | | - valid_pyt_quant = bool(pyt_kv_cache_dtype in _VALID_KV_CACHE_DTYPES) |
125 | | - mapped_pyt_quant = _KV_CACHE_MAP.get(pyt_kv_cache_dtype, None) |
126 | | - |
127 | | - # If we're letting the checkpoint dictate the quant with auto, simply |
128 | | - # return and do not modify the checkpoint. |
129 | | - if pyt_kv_cache_dtype == "auto": |
130 | | - logger.info( |
131 | | - f'KV cache quantization set to "{pyt_kv_cache_dtype}". Using ' |
132 | | - "checkpoint KV quantization.") |
133 | | - return |
134 | | - |
135 | | - # If we have an invalid quantization, simply raise an exception. |
136 | | - if not valid_pyt_quant: |
137 | | - raise ValueError( |
138 | | - "Overriding KV cache quantization with an invalid type " |
139 | | - f'"PyTorchConfig.kv_cache_dtype="{pyt_kv_cache_dtype}" ' |
140 | | - f'Accepted types are "{_VALID_KV_CACHE_DTYPES}".') |
141 | | - |
142 | | - # If we get to this point we have a valid quantization setting, but if |
143 | | - # we have an existing setting and it doesn't match we shouldn't proceed. |
144 | | - if kv_cache_quant is not None and mapped_pyt_quant != kv_cache_quant: |
145 | | - raise RuntimeError( |
146 | | - "Attempting to override KV cache quantization " |
147 | | - f'"{kv_cache_quant}" with PyTorchConfig.kv_cache_dtype=' |
148 | | - f'"{pyt_kv_cache_dtype}". You cannot override a checkpoint with a ' |
149 | | - "pre-quantized KV cache that doesn't match.") |
150 | | - |
151 | | - # We have an open ended KV cache in the checkpoint |
152 | | - # and we have a specified override. |
153 | | - model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant |
154 | | - |
155 | | - |
156 | | -def initialize_dummy_weights( |
157 | | - model: torch.nn.Module, |
158 | | - low: float = -1e-3, |
159 | | - high: float = 1e-3, |
160 | | - seed: int = 0, |
161 | | -) -> None: |
162 | | - """ |
163 | | - This is similar to this function in SGLang with a few changes: |
164 | | - https://github.com/sgl-project/sglang/blob/e074e76b31d4fff13e87a455dbc3acdaa92c537a/python/sglang/srt/model_loader/weight_utils.py#L577 |
165 | | - |
166 | | - This method is used to initialize weights with dummy values for testing |
167 | | - models without checkpoints. Unquantized (FP16/BF16/etc) values are generated |
168 | | - from a uniform distribution over the interval (low, high). |
169 | | - |
170 | | - For some quantized types (FP8/NVFP4), torch has no built-in way to generate random values. |
171 | | - We simply generate values uniformly across an interval that has been empirically verified |
172 | | - to not generate NaNs/inf for these. |
173 | | - """ |
174 | | - |
175 | | - def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]: |
176 | | - # These values are not necessarily the largest possible min/max, |
177 | | - # they need to be small enough to avoid NaNs. |
178 | | - if dtype in (torch.float8_e4m3fn, torch.int8): |
179 | | - return (-3.0, 3.0) |
180 | | - |
181 | | - elif dtype == float4_e2m1x2: |
182 | | - # These correspond to bits of 2 packed FP4 values. |
183 | | - # Because we only go up to 64, the high 4 bits will |
184 | | - # always be 0. But this is fine - we just need values |
185 | | - # that won't generate NaNs. |
186 | | - return (0, 64) |
187 | | - |
188 | | - else: |
189 | | - raise NotImplementedError(f"Unknown quantized type: {dtype}.") |
190 | | - |
191 | | - for param in model.state_dict().values(): |
192 | | - generator = torch.Generator(device=param.data.device) |
193 | | - generator.manual_seed(seed) |
194 | | - dtype = param.data.dtype |
195 | | - |
196 | | - if param.data.element_size() < 2: |
197 | | - # We need to do a cast/round since torch doesn't have uniform_ |
198 | | - # support for these dtypes. |
199 | | - tmp_param = torch.empty(param.data.shape, |
200 | | - dtype=torch.float16, |
201 | | - device=param.data.device) |
202 | | - |
203 | | - quant_min, quant_max = _get_random_min_max(dtype) |
204 | | - tmp_param = tmp_param.uniform_(quant_min, |
205 | | - quant_max, |
206 | | - generator=generator) |
207 | | - |
208 | | - param.data.copy_(tmp_param.to(dtype)) |
209 | | - |
210 | | - # Note: no need to mess with int32 params, these are probably |
211 | | - # constants and not weights. |
212 | | - elif torch.is_floating_point(param): |
213 | | - param.uniform_(low, high, generator=generator) |
214 | | - |
215 | | - |
216 | | -def get_rank_model_storage(model): |
217 | | - total_bytes = 0 |
218 | | - for _, param in model.named_parameters(): |
219 | | - if param.device.type == 'cuda' and param.device.index == torch.cuda.current_device( |
220 | | - ): |
221 | | - total_bytes += param.element_size() * param.nelement() |
222 | | - for _, buf in model.named_buffers(): |
223 | | - if buf.device.type == 'cuda' and buf.device.index == torch.cuda.current_device( |
224 | | - ): |
225 | | - total_bytes += buf.element_size() * buf.nelement() |
226 | | - return total_bytes |
227 | | - |
228 | | - |
229 | 90 | def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int], |
230 | 91 | max_batch_size: int, max_num_tokens: int, |
231 | 92 | max_draft_len: int, |
@@ -302,20 +163,17 @@ def __init__( |
302 | 163 | ) |
303 | 164 | |
304 | 165 | attn_backend = pytorch_backend_config.attn_backend |
305 | | - self.model = self._load_model( |
306 | | - model_path, |
| 166 | + loader = ModelLoader( |
| 167 | + pytorch_backend_config=pytorch_backend_config, |
307 | 168 | mapping=self.mapping, |
308 | | - checkpoint_loader=checkpoint_loader, |
309 | | - attn_backend=attn_backend, |
310 | | - moe_backend=pytorch_backend_config.moe_backend, |
311 | | - moe_disable_finalize_fusion=pytorch_backend_config. |
312 | | - moe_disable_finalize_fusion, |
313 | | - load_format=pytorch_backend_config.load_format, |
| 169 | + spec_config=self.spec_config, |
314 | 170 | max_num_tokens=max_num_tokens, |
315 | | - moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens, |
316 | | - moe_load_balancer=pytorch_backend_config.moe_load_balancer, |
| 171 | + max_seq_len=max_seq_len, |
317 | 172 | lora_config=lora_config, |
318 | | - drafting_loop_wrapper=drafting_loop_wrapper) |
| 173 | + ) |
| 174 | + self.model = loader.load(checkpoint_dir=model_path, |
| 175 | + checkpoint_loader=checkpoint_loader, |
| 176 | + drafting_loop_wrapper=drafting_loop_wrapper) |
319 | 177 | # In case that some tests use stub models and override `_load_model`. |
320 | 178 | if not hasattr(self.model, 'extra_attrs'): |
321 | 179 | self.model.extra_attrs = {} |
@@ -944,141 +802,6 @@ def __del__(self) -> None: |
944 | 802 | # Release model weights. |
945 | 803 | release_gc() |
946 | 804 | |
947 | | - def _load_model(self, |
948 | | - checkpoint_dir: str, |
949 | | - checkpoint_loader: BaseCheckpointLoader, |
950 | | - load_format: LoadFormat, |
951 | | - max_num_tokens: int, |
952 | | - moe_max_num_tokens: Optional[int] = None, |
953 | | - moe_load_balancer: Optional[MoeLoadBalancerConfig] = None, |
954 | | - lora_config: Optional[LoraConfig] = None, |
955 | | - drafting_loop_wrapper: Optional[Callable[ |
956 | | - [torch.nn.Module], torch.nn.Module]] = None, |
957 | | - **kwargs) -> DecoderModelForCausalLM: |
958 | | - config = checkpoint_loader.load_config( |
959 | | - checkpoint_dir, |
960 | | - trust_remote_code=True, |
961 | | - enable_min_latency=self.pytorch_backend_config.enable_min_latency, |
962 | | - use_cuda_graph=self.pytorch_backend_config.use_cuda_graph, |
963 | | - force_dynamic_quantization=self.pytorch_backend_config. |
964 | | - force_dynamic_quantization, |
965 | | - spec_config=self.spec_config, |
966 | | - max_num_tokens=max_num_tokens, |
967 | | - max_seq_len=self.max_seq_len, |
968 | | - moe_max_num_tokens=moe_max_num_tokens, |
969 | | - moe_load_balancer=moe_load_balancer, |
970 | | - lora_config=lora_config, |
971 | | - allreduce_strategy=self.pytorch_backend_config.allreduce_strategy, |
972 | | - mm_encoder_only=self.pytorch_backend_config.mm_encoder_only, |
973 | | - **kwargs) |
974 | | - |
975 | | - validate_and_set_kv_cache_quant( |
976 | | - config, self.pytorch_backend_config.kv_cache_dtype) |
977 | | - validate_and_set_mamba_ssm_cache_dtype( |
978 | | - config, self.pytorch_backend_config.mamba_ssm_cache_dtype) |
979 | | - |
980 | | - num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0")) |
981 | | - if num_layers > 0: |
982 | | - config.pretrained_config.num_hidden_layers = num_layers |
983 | | - for sub_config in ["text_config", "vision_config"]: |
984 | | - if hasattr(config.pretrained_config, sub_config): |
985 | | - getattr(config.pretrained_config, |
986 | | - sub_config).num_hidden_layers = num_layers |
987 | | - |
988 | | - with timing("Model init total"), maybe_create_moe_load_balancer( |
989 | | - config, self.mapping) as moe_load_balancer: |
990 | | - |
991 | | - try: |
992 | | - # config will be modified in-place for some models, like Qwen2 |
993 | | - config_copy = copy.deepcopy(config) |
994 | | - with MetaInitMode(): |
995 | | - model = AutoModelForCausalLM.from_config(config_copy) |
996 | | - |
997 | | - memo = dict() |
998 | | - |
999 | | - def init_meta_tensor(t: torch.Tensor): |
1000 | | - if t.device != torch.device('meta'): |
1001 | | - return t |
1002 | | - if t not in memo: |
1003 | | - memo[t] = torch.empty_like(t, device='cuda') |
1004 | | - return memo[t] |
1005 | | - |
1006 | | - model._apply(init_meta_tensor) |
1007 | | - config = config_copy |
1008 | | - |
1009 | | - except Exception: |
1010 | | - logger.info( |
1011 | | - f"Fallback to regular model init: {traceback.format_exc(limit=1)}\n" |
1012 | | - ) |
1013 | | - model = AutoModelForCausalLM.from_config(config) |
1014 | | - |
1015 | | - model.to("cuda") |
1016 | | - rank_model_storage = get_rank_model_storage(model) |
1017 | | - logger.info( |
1018 | | - f"Use {rank_model_storage / (1024**3):.2f} GB for model weights." |
1019 | | - ) |
1020 | | - if load_format == LoadFormat.AUTO: |
1021 | | - if hasattr(model, 'llm_checkpoint_dir'): |
1022 | | - weights = checkpoint_loader.load_weights( |
1023 | | - model.llm_checkpoint_dir) |
1024 | | - else: |
1025 | | - weights = checkpoint_loader.load_weights(checkpoint_dir) |
1026 | | - |
1027 | | - weight_mapper = checkpoint_loader.get_initialized_weight_mapper( |
1028 | | - model, config) |
1029 | | - self._call_load_weights(model.load_weights, weights, |
1030 | | - weight_mapper) |
1031 | | - |
1032 | | - if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( |
1033 | | - ): |
1034 | | - weights = checkpoint_loader.load_weights( |
1035 | | - self.spec_config.speculative_model_dir) |
1036 | | - self._call_load_weights(model.load_draft_weights, weights, |
1037 | | - weight_mapper) |
1038 | | - |
1039 | | - elif load_format == LoadFormat.DUMMY: |
1040 | | - initialize_dummy_weights(model) |
1041 | | - if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( |
1042 | | - ): |
1043 | | - model.draft_model.load_weights_from_target_model(model) |
1044 | | - |
1045 | | - elif load_format == LoadFormat.VISION_ONLY: |
1046 | | - # Vision weights are already loaded within the model. |
1047 | | - logger.info( |
1048 | | - "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights." |
1049 | | - ) |
1050 | | - |
1051 | | - else: |
1052 | | - raise NotImplementedError( |
1053 | | - f"No load support for load format: {load_format}") |
1054 | | - |
1055 | | - if isinstance(moe_load_balancer, MoeLoadBalancer): |
1056 | | - setattr(self, "moe_load_balancer", moe_load_balancer) |
1057 | | - moe_load_balancer.register_weight_slots_after_to_cuda() |
1058 | | - logger.info("moe_load_balancer finalizing model...") |
1059 | | - moe_load_balancer.finalize_model() |
1060 | | - logger.info("moe_load_balancer finalize model done") |
1061 | | - |
1062 | | - torch.cuda.current_stream().synchronize() |
1063 | | - |
1064 | | - if drafting_loop_wrapper is not None: |
1065 | | - model = drafting_loop_wrapper(model) |
1066 | | - self.model_is_wrapped = True |
1067 | | - else: |
1068 | | - self.model_is_wrapped = False |
1069 | | - |
1070 | | - return model |
1071 | | - |
1072 | | - def _call_load_weights(self, load_method, weights, weight_mapper): |
1073 | | - # TODO smor- this is a temporary solution to load weights. |
1074 | | - # Once checkpoint format is unified, this method will be removed. |
1075 | | - from inspect import getfullargspec |
1076 | | - args = getfullargspec(load_method).args |
1077 | | - if "weight_mapper" in args: |
1078 | | - load_method(weights, weight_mapper=weight_mapper) |
1079 | | - else: |
1080 | | - load_method(weights) |
1081 | | - |
1082 | 805 | def _init_max_seq_len(self): |
1083 | 806 | # For mm_encoder_only mode, infer_max_seq_len() is for LLM decoder models |
1084 | 807 | if hasattr(self.model, 'infer_max_seq_len'): |