@@ -186,15 +186,21 @@ def __init__(
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             raise NotImplementedError
 
-        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        def load_weights(
+            self,
+            weights: Iterable[tuple[str, torch.Tensor]],
+            load_lm_head: bool = False,
+        ):
             # TODO: Support uninitialized params tracking
 
-            # We have deleted this attribute, so don't load it
-            weights = (
-                (name, data)
-                for name, data in weights
-                if not name.startswith("lm_head.")
-            )
+            # For most pooling models: We have deleted this attribute, so don't load it.
+            # For converting an LLM into a seq cls model, we need the lm_head.
+            if not load_lm_head:
+                weights = (
+                    (name, data)
+                    for name, data in weights
+                    if not name.startswith("lm_head.")
+                )
 
             # If `*ForCausalLM` defines `load_weights` on the inner model
             # and there are no other inner modules with parameters,
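A minimal sketch of how the new load_lm_head flag behaves. The _filter_lm_head helper and the demo weights below are hypothetical and not part of the change; only the generator filter mirrors the hunk above.

from collections.abc import Iterable

import torch


def _filter_lm_head(
    weights: Iterable[tuple[str, torch.Tensor]],
    load_lm_head: bool = False,
) -> Iterable[tuple[str, torch.Tensor]]:
    # Default (pooling models): drop checkpoint entries under "lm_head.".
    # Passing load_lm_head=True keeps them, e.g. when converting an LLM
    # into a sequence classification model.
    if not load_lm_head:
        weights = (
            (name, data)
            for name, data in weights
            if not name.startswith("lm_head.")
        )
    return weights


demo = [
    ("model.embed_tokens.weight", torch.empty(1)),
    ("lm_head.weight", torch.empty(1)),
]
assert [n for n, _ in _filter_lm_head(demo)] == ["model.embed_tokens.weight"]
assert [n for n, _ in _filter_lm_head(demo, load_lm_head=True)] == [n for n, _ in demo]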
@@ -431,8 +437,12 @@ def load_weights_using_from_2_way_softmax(
     )
     model.lm_head = model.lm_head.tie_weights(embed_tokens)
 
-    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
-    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
+    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
+    # function, so we need to use this hacky method to obtain it.
+    pooling_model_cls = next(
+        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+    )
+    loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)
 
     from vllm.transformers_utils.tokenizer import get_tokenizer
 
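For context, a toy reconstruction of why the pooling base is looked up by name in the MRO: ModelForPooling only exists inside the factory function, so there is no module-level name to call directly. The classes below are simplified stand-ins, not the real vLLM adapters.

def _create_pooling_model_cls(orig_cls):
    # The adapter class exists only inside this factory, so callers cannot
    # reference it by a module-level name.
    class ModelForPooling(orig_cls):
        def load_weights(self, weights, load_lm_head=False):
            if not load_lm_head:
                weights = [(n, d) for n, d in weights if not n.startswith("lm_head.")]
            return super().load_weights(weights)

    return ModelForPooling


class FakeCausalLM:
    def load_weights(self, weights):
        return sorted(name for name, _ in weights)


class FakeForSequenceClassification(_create_pooling_model_cls(FakeCausalLM)):
    def load_weights(self, weights):
        # type(self).__mro__[1] happens to resolve to ModelForPooling today,
        # but silently breaks if another base is inserted; finding the class
        # by name is robust and also avoids recursing back into this method.
        pooling_model_cls = next(
            x for x in type(self).__mro__ if x.__name__ == "ModelForPooling"
        )
        return pooling_model_cls.load_weights(self, weights, load_lm_head=True)


model = FakeForSequenceClassification()
print(model.load_weights([("lm_head.weight", None), ("model.norm.weight", None)]))
# -> ['lm_head.weight', 'model.norm.weight']: lm_head weights are kept here.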