Commit 802748b

[Bugfix] Fix Qwen3-Reranker-8B load (vllm-project#28117)
Signed-off-by: wang.yuqi <[email protected]>
1 parent: faedbb4 · commit: 802748b

1 file changed, 19 insertions(+), 9 deletions(-)

vllm/model_executor/models/adapters.py

@@ -186,15 +186,21 @@ def __init__(
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             raise NotImplementedError
 
-        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        def load_weights(
+            self,
+            weights: Iterable[tuple[str, torch.Tensor]],
+            load_lm_head: bool = False,
+        ):
             # TODO: Support uninitialized params tracking
 
-            # We have deleted this attribute, so don't load it
-            weights = (
-                (name, data)
-                for name, data in weights
-                if not name.startswith("lm_head.")
-            )
+            # For most pooling models: We have deleted this attribute, so don't load it.
+            # For converting an LLM into a seq cls model, we need the lm_head.
+            if not load_lm_head:
+                weights = (
+                    (name, data)
+                    for name, data in weights
+                    if not name.startswith("lm_head.")
+                )
 
             # If `*ForCausalLM` defines `load_weights` on the inner model
             # and there are no other inner modules with parameters,
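
For context, here is a minimal, self-contained sketch of the behavior the new load_lm_head flag gates in this hunk. The helper name filter_lm_head_weights is hypothetical (it does not exist in vLLM); it only illustrates how "lm_head.*" entries are lazily dropped from the weights iterator unless the caller opts in.

# Hypothetical standalone helper (not part of vLLM) mirroring the filtering
# that ModelForPooling.load_weights now gates behind load_lm_head.
from collections.abc import Iterable

import torch


def filter_lm_head_weights(
    weights: Iterable[tuple[str, torch.Tensor]],
    load_lm_head: bool = False,
) -> Iterable[tuple[str, torch.Tensor]]:
    # Pooling models delete lm_head, so its checkpoint entries must be skipped;
    # an LLM converted to a seq-cls model still needs them, hence the opt-in flag.
    if not load_lm_head:
        weights = (
            (name, data) for name, data in weights if not name.startswith("lm_head.")
        )
    return weights


if __name__ == "__main__":
    dummy = [
        ("model.embed_tokens.weight", torch.zeros(1)),
        ("lm_head.weight", torch.zeros(1)),
    ]
    print([n for n, _ in filter_lm_head_weights(dummy)])        # lm_head dropped
    print([n for n, _ in filter_lm_head_weights(dummy, True)])  # lm_head kept
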
@@ -431,8 +437,12 @@ def load_weights_using_from_2_way_softmax(
     )
     model.lm_head = model.lm_head.tie_weights(embed_tokens)
 
-    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
-    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
+    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
+    # function, so we need to use this hacky method to obtain it.
+    pooling_model_cls = next(
+        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+    )
+    loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)
 
     from vllm.transformers_utils.tokenizer import get_tokenizer
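
This second hunk replaces the brittle type(model).__mro__[1] lookup with a search by class name, because ModelForPooling is created dynamically and cannot be imported. Below is a minimal sketch of that pattern; the class names are stand-ins, not vLLM's real hierarchy, and only the lookup technique matters.

# Minimal sketch of the MRO-by-name lookup used in the hunk above.
# All class names here are hypothetical.


def make_pooling_cls(base: type) -> type:
    # Dynamically defined class, analogous to ModelForPooling in adapters.py:
    # it cannot be imported by name from outside this function.
    class ModelForPooling(base):
        def load_weights(self, weights, load_lm_head: bool = False):
            return f"ModelForPooling.load_weights(load_lm_head={load_lm_head})"

    return ModelForPooling


class FakeCausalLM:
    pass


class ModelForSequenceClassification(make_pooling_cls(FakeCausalLM)):
    def load_weights(self, weights):
        # Hard-coding type(self).__mro__[1] breaks if another class sits between
        # this one and ModelForPooling; matching on __name__ stays robust while
        # still skipping this override, which avoids infinite recursion.
        pooling_model_cls = next(
            x for x in type(self).__mro__ if x.__name__ == "ModelForPooling"
        )
        return pooling_model_cls.load_weights(self, weights, load_lm_head=True)


if __name__ == "__main__":
    print(ModelForSequenceClassification().load_weights([]))
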
