Commit 1d64287 (parent 9ccef8e)

[torchao] fix safetensors for sharding (vllm-project#28169)

Signed-off-by: Angel Li <[email protected]>

3 files changed (+23, -11 lines)

tests/quantization/test_torchao.py
Lines changed: 4 additions & 5 deletions

@@ -225,13 +225,12 @@ def test_reload_weights():
 @pytest.mark.skip(
     reason="since torchao nightly is only compatible with torch nightly"
     "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
-    "torchao tests that requires newer versions (0.14.0.dev+) for now"
+    "torchao tests that requires newer versions (0.15.0.dev+) for now"
 )
-def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner):
+def test_safetensors_model_loading_with_params(vllm_runner):
     torch._dynamo.reset()
-    model_name = (
-        "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
-    )
+    # using this model to test safetensors loading with file sharding
+    model_name = "torchao-testing/Qwen3-8B-INT4-0.15.0dev-safetensors"
     with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
         output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
 

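The new test model matters because, per the added comment, the Qwen3-8B checkpoint is sharded across multiple .safetensors files, which is exactly the case this commit fixes; the old opt-125m checkpoint was a single file. A rough sketch of what the test exercises through vLLM's public API, assuming vllm and a torchao 0.15.0+ nightly are installed (the vllm_runner fixture wraps approximately this flow; the snippet is illustrative, not the fixture's implementation):

# Illustrative sketch only: approximates what the vllm_runner fixture does
# in the updated test. Assumes vllm and a torchao >= 0.15.0 nightly.
from vllm import LLM, SamplingParams

# A sharded torchao safetensors checkpoint: weights are split across several
# .safetensors files, so tensor-subclass components can straddle shards.
llm = LLM(
    model="torchao-testing/Qwen3-8B-INT4-0.15.0dev-safetensors",
    dtype="bfloat16",
)
# Greedy decoding (temperature=0.0) mirrors the test's generate_greedy call.
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=4),
)
print(outputs[0].outputs[0].text)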
vllm/model_executor/model_loader/default_loader.py
Lines changed: 1 addition & 1 deletion

@@ -279,7 +279,7 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
         if (
             hasattr(quant_config, "is_checkpoint_torchao_serialized")
             and quant_config.is_checkpoint_torchao_serialized
-            and torchao_version_at_least("0.14.0")
+            and torchao_version_at_least("0.15.0")
         ):
             self.load_config.safetensors_load_strategy = "torchao"
 

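This one-line change raises the minimum torchao version that enables the "torchao" safetensors load strategy, keeping the gate in step with the new unflatten API used in weight_utils.py below. A self-contained sketch of the gating idea; torchao_version_at_least is vLLM's own helper, and the stub here is an assumed approximation of it, not its actual code:

# Sketch of the version gate's intent. torchao_version_at_least is a vLLM
# helper; this stub is an assumption built on packaging.version (the real
# helper may treat 0.15.0.dev builds differently).
from packaging import version


def torchao_version_at_least_stub(minimum: str) -> bool:
    try:
        import torchao
    except ImportError:
        return False
    return version.parse(torchao.__version__) >= version.parse(minimum)


# Only opt into the torchao-specific safetensors path when the installed
# torchao has the sharding-aware API (0.15.0+); otherwise vLLM keeps the
# default loading behavior.
if torchao_version_at_least_stub("0.15.0"):
    print("safetensors_load_strategy = 'torchao'")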
vllm/model_executor/model_loader/weight_utils.py
Lines changed: 18 additions & 5 deletions

@@ -595,6 +595,9 @@ def safetensors_weights_iterator(
     if safetensors_load_strategy == "eager":
         loading_desc += " (eager)"
 
+    state_dict = {}
+    leftover_state_dict: dict[str, torch.Tensor] = {}
+
     for st_file in tqdm(
         hf_weights_files,
         desc=loading_desc,
@@ -606,22 +609,32 @@
             state_dict = load(f.read())
             yield from state_dict.items()
         elif safetensors_load_strategy == "torchao":
-            if not torchao_version_at_least("0.14.0"):
+            # we can't load flattened torchao tensor subclasses directly into the model
+            # instead we reconstruct the subclasses here before returning
+            if not torchao_version_at_least("0.15.0"):
                 raise ValueError(
-                    "Please use torchao version >= 0.14.0 \
+                    "Please use torchao version >= 0.15.0 \
                     to load torchao safetensors checkpoint"
                 )
             from torchao.prototype.safetensors.safetensors_support import (
                 unflatten_tensor_state_dict,
             )
 
             with safe_open(st_file, framework="pt") as f:
-                state_dict = {}
                 for name in f.keys():  # noqa: SIM118
                     state_dict[name] = f.get_tensor(name)
+
+                # update with leftover tensor data from previous iteration, if any
+                state_dict.update(leftover_state_dict)
                 metadata = f.metadata()
-                updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata)
-                yield from updated_state_dict.items()
+                # due to sharded checkpoints, we are not guaranteed that we have all
+                # tensor subclass data on one file
+                # state_dict has the leftover data from this step and we wait for
+                # missing information to be provided in a future iteration
+                unflattened_state_dict, leftover_state_dict = (
+                    unflatten_tensor_state_dict(state_dict, metadata)
+                )
+                yield from unflattened_state_dict.items()
         else:
             with safe_open(st_file, framework="pt") as f:
                 for name in f.keys():  # noqa: SIM118

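The substantive change is the carry pattern: in a sharded checkpoint, the flattened components of a single torchao tensor subclass (say, quantized data in one shard and its scale in another) are not guaranteed to sit in the same file, so whatever unflatten_tensor_state_dict cannot yet reassemble comes back as leftovers and is merged into the tensors read from the next shard. A self-contained sketch of that pattern; toy_unflatten and its ":qdata"/":scale" key scheme are invented for illustration and stand in for torchao's real unflatten_tensor_state_dict:

import torch


def toy_unflatten(state_dict):
    """Pair '<name>:qdata' with '<name>:scale'; return (ready, leftover)."""
    ready, leftover = {}, dict(state_dict)
    base_names = {key.rsplit(":", 1)[0] for key in state_dict}
    for name in base_names:
        qdata_key, scale_key = f"{name}:qdata", f"{name}:scale"
        if qdata_key in leftover and scale_key in leftover:
            # All components present: reconstruct the tensor and consume them.
            ready[name] = leftover.pop(qdata_key) * leftover.pop(scale_key)
    return ready, leftover


# Two "shards": layer0's scale only arrives in the second file.
shards = [
    {
        "layer0:qdata": torch.ones(2),
        "layer1:qdata": torch.ones(2),
        "layer1:scale": torch.tensor(2.0),
    },
    {"layer0:scale": torch.tensor(3.0)},
]

leftover: dict[str, torch.Tensor] = {}
for shard in shards:
    state_dict = dict(shard)
    state_dict.update(leftover)  # merge carry-over from earlier shards
    ready, leftover = toy_unflatten(state_dict)
    for name, tensor in ready.items():
        print(name, tensor)  # the real iterator yields these items
assert not leftover, "all subclass components should be consumed by the end"

On the first shard only layer1 can be reconstructed, so layer0:qdata is carried forward; the second shard supplies layer0:scale and the pair is completed, mirroring how the iterator above defers incomplete subclasses to a later file.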