46 changes: 46 additions & 0 deletions tests/model_executor/model_loader/test_bitsandbytes_loader.py
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.bitsandbytes_loader import (
BitsAndBytesModelLoader,
)


class _DummyBitsAndBytesLoader(BitsAndBytesModelLoader):
"""Test helper that bypasses any real HF interactions."""

def __init__(
self, load_config: LoadConfig, mock_result: tuple[str, list[str], str]
):
super().__init__(load_config)
self._mock_result = mock_result

def _get_weight_files( # type: ignore[override]
self,
model_name_or_path: str,
allowed_patterns: list[str],
revision: str | None = None,
) -> tuple[str, list[str], str]:
return self._mock_result


def test_bitsandbytes_loader_detects_safetensors_from_files(tmp_path):
"""Even if the allow-pattern looks like *.bin, safetensors files are detected."""

llm_dir = tmp_path / "llm"
llm_dir.mkdir()
safetensor = llm_dir / "model-00001-of-00002.safetensors"
safetensor.write_bytes(b"test")

load_config = LoadConfig()
loader = _DummyBitsAndBytesLoader(
load_config,
mock_result=(str(tmp_path), [str(safetensor)], "*.bin"),
)

files, use_safetensors = loader._prepare_weights(str(tmp_path), revision=None)

assert use_safetensors is True
assert files == [str(safetensor)]
35 changes: 35 additions & 0 deletions tests/model_executor/model_loader/test_default_loader_subfolder.py
@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader


def test_default_loader_prefers_llm_subfolder_and_filters_with_index(tmp_path):
# Create local repo layout with llm/ subfolder
llm_dir = tmp_path / "llm"
llm_dir.mkdir()

keep = llm_dir / "model-00001-of-00002.safetensors"
drop = llm_dir / "model-00002-of-00002.safetensors"
keep.write_bytes(b"0")
drop.write_bytes(b"0")

# Create index file within llm/ that only references the first shard
index = llm_dir / "model.safetensors.index.json"
index.write_text(json.dumps({"weight_map": {"w": keep.name}}))

# Default loader in auto format should find llm/*.safetensors
# and use the subfolder index.

loader = DefaultModelLoader(LoadConfig(load_format="auto"))
hf_folder, files, use_safetensors = loader._prepare_weights(
str(tmp_path),
revision=None,
fall_back_to_pt=True,
allow_patterns_overrides=None,
)

assert hf_folder == str(tmp_path)
assert use_safetensors is True
assert files == [str(keep)]
24 changes: 24 additions & 0 deletions tests/model_executor/test_weight_utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
import tempfile

@@ -11,6 +12,7 @@
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf,
enable_hf_transfer,
filter_duplicate_safetensors_files,
)


@@ -61,6 +63,28 @@ def test_download_weights_from_hf():
)


def test_filter_duplicate_safetensors_files_with_subfolder(tmp_path):
llm_dir = tmp_path / "llm"
llm_dir.mkdir()
kept_file = llm_dir / "model-00001-of-00002.safetensors"
kept_file.write_bytes(b"0")
dropped_file = tmp_path / "other.safetensors"
dropped_file.write_bytes(b"0")

index_path = llm_dir / "model.safetensors.index.json"
index_path.write_text(
json.dumps({"weight_map": {"w": "model-00001-of-00002.safetensors"}})
)

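# Only the file referenced by the llm/ index should survive filtering.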
filtered = filter_duplicate_safetensors_files(
[str(kept_file), str(dropped_file)],
str(tmp_path),
"llm/model.safetensors.index.json",
)

assert filtered == [str(kept_file)]


if __name__ == "__main__":
test_hf_transfer_auto_activation()
test_download_weights_from_hf()
53 changes: 53 additions & 0 deletions tests/tokenization/test_tokenizer_llm_subfolder.py
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

from vllm.transformers_utils import tokenizer as tokenizer_module
from vllm.transformers_utils.tokenizer import get_tokenizer


class _DummyTokenizer:
def __init__(self):
self.all_special_ids: list[int] = []
self.all_special_tokens: list[str] = []
self.all_special_tokens_extended: list[str] = []
self.special_tokens_map: dict[str, str] = {}
self.vocab_size = 1

def get_vocab(self) -> dict[str, int]:
return {"a": 0}

def __len__(self) -> int: # pragma: no cover - trivial
return 1

def decode(self, *args: Any, **kwargs: Any) -> str:
return ""

def encode(self, *args: Any, **kwargs: Any) -> list[int]:
return []


def test_tokenizer_prefers_llm_subfolder(monkeypatch):
captured = {}

def fake_file_exists(repo_id: str, file_name: str, **kwargs: Any) -> bool:
return file_name == "llm/tokenizer.json"

def fake_auto_from_pretrained(*args: Any, **kwargs: Any):
captured["subfolder"] = kwargs.get("subfolder")
return _DummyTokenizer()

monkeypatch.setattr(tokenizer_module, "file_exists", fake_file_exists)
monkeypatch.setattr(
tokenizer_module.AutoTokenizer,
"from_pretrained",
classmethod(
lambda cls, *args, **kwargs: fake_auto_from_pretrained(*args, **kwargs)
),
)

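# file_exists only reports llm/tokenizer.json, so get_tokenizer should load with subfolder="llm".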
tokenizer = get_tokenizer("fake/model")

assert tokenizer is not None
assert captured["subfolder"] == "llm"
71 changes: 71 additions & 0 deletions tests/transformers_utils/test_hf_config_parser_subfolder.py
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from transformers import GenerationConfig, PretrainedConfig

from vllm.transformers_utils import config as config_module
from vllm.transformers_utils.config import HFConfigParser, try_get_generation_config


def test_hf_config_parser_uses_llm_subfolder(monkeypatch):
parser = HFConfigParser()
base_config = PretrainedConfig()
subfolder_config = PretrainedConfig()

def fake_get_config_dict(
cls,
model: str | bytes,
revision: str | None = None,
code_revision: str | None = None,
**kwargs,
):
return {"llm_cfg": {}}, base_config

def fake_file_exists(model: str | bytes, config_name: str, revision: str | None):
return config_name == "llm/config.json"

auto_called = {}

def fake_auto_from_pretrained(cls, *args, **kwargs):
auto_called["subfolder"] = kwargs.get("subfolder")
return subfolder_config

monkeypatch.setattr(
PretrainedConfig,
"get_config_dict",
classmethod(fake_get_config_dict),
)
monkeypatch.setattr(config_module, "file_or_path_exists", fake_file_exists)
monkeypatch.setattr(
config_module.AutoConfig,
"from_pretrained",
classmethod(fake_auto_from_pretrained),
)

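# With llm/config.json reported as present, parse() should reload the config via AutoConfig using subfolder="llm".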
returned_dict, returned_config = parser.parse("fake/model", trust_remote_code=False)

assert returned_dict == {"llm_cfg": {}}
assert returned_config is subfolder_config
assert auto_called["subfolder"] == "llm"


def test_try_get_generation_config_llm_subfolder(monkeypatch):
calls = []

def fake_from_pretrained(cls, model: str, **kwargs):
calls.append(kwargs.get("subfolder"))
if len(calls) == 1:
raise OSError("missing")
return GenerationConfig()

monkeypatch.setattr(
config_module.GenerationConfig,
"from_pretrained",
classmethod(fake_from_pretrained),
)

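# The first from_pretrained call (no subfolder) fails, so a retry with subfolder="llm" is expected.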
result = try_get_generation_config("fake/model", trust_remote_code=False)

assert isinstance(result, GenerationConfig)
assert calls == [None, "llm"]
58 changes: 42 additions & 16 deletions vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -84,6 +84,7 @@ def __init__(self, load_config: LoadConfig):
self.pre_quant: bool = False
self.load_8bit: bool = False
self.is_pool_model: bool = False
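# Optional override of the weight-file glob patterns; populated from the model in _initialize_loader_state().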
self.allow_patterns_overrides: list[str] | None = None

def _get_weight_files(
self,
@@ -97,14 +98,27 @@ def _get_weight_files(
is_local = os.path.isdir(model_name_or_path)

if is_local:
for pattern in allowed_patterns:
patterns = list(allowed_patterns)
# Prefer subfolder patterns if common subfolder exists locally.
if os.path.isdir(os.path.join(model_name_or_path, "llm")):
patterns = [f"llm/{p}" for p in allowed_patterns] + patterns
for pattern in patterns:
weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
if weight_files:
return model_name_or_path, weight_files, pattern
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
for pattern in allowed_patterns:
search_patterns = list(allowed_patterns)
# Prefer 'llm/' weights when present in the repo.
if any(
f.startswith("llm/") and f.endswith((".safetensors", ".bin", ".pt"))
for f in repo_files
):
search_patterns = [
f"llm/{p}" for p in allowed_patterns
] + search_patterns
for pattern in search_patterns:
matching_files = fnmatch.filter(repo_files, pattern)
if matching_files:
hf_folder = download_weights_from_hf(
@@ -129,26 +143,36 @@ def _prepare_weights(

allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

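# Honor model-specific weight patterns (e.g. weights under a subfolder) when provided.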
allow_patterns_overrides = getattr(self, "allow_patterns_overrides", None)
if allow_patterns_overrides is not None:
allowed_patterns = list(allow_patterns_overrides)

hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
model_name_or_path, allowed_patterns, revision
)

use_safetensors = matched_pattern == "*.safetensors"
# Detect safetensors robustly (pattern may include subfolder)
use_safetensors = matched_pattern.endswith(".safetensors")
# Additionally guard by checking actual files
if not use_safetensors:
use_safetensors = any(f.endswith(".safetensors") for f in hf_weights_files)
is_local = os.path.isdir(model_name_or_path)
index_file = SAFE_WEIGHTS_INDEX_NAME
# If weights live under a subfolder (e.g., 'llm/*.safetensors'),
# the index file will also live there.
if "/" in matched_pattern:
folder_prefix = matched_pattern.rsplit("/", 1)[0] + "/"
else:
folder_prefix = ""
index_file = folder_prefix + SAFE_WEIGHTS_INDEX_NAME
if use_safetensors and not is_local:
# Download index for safetensors to select correct shards.
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file
)
@@ -587,6 +611,8 @@ def _initialize_loader_state(
self._get_bnb_target_modules(model)
self._classify_module_sharding(model)

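# Pick up optional weight-file patterns advertised by the model; used later by _prepare_weights().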
self.allow_patterns_overrides = getattr(model, "allow_patterns_overrides", None)

def _dequantize_dq(self, quant_states: Any):
"""
When BNB employs Double Quantization, we perform the dequantization of