Description
Your current environment
- vLLM Version: 0.11.2
- Transformers Version: 4.57
- Model: Qwen3VLForConditionalGeneration
🐛 Describe the bug
I have observed an inconsistency in the output of the forward method for the Qwen3VLForConditionalGeneration class between vLLM (version 0.11.2) and Transformers (version 4.57).
In the Transformers library, the last hidden state returned (outputs.hidden_states[0, -1, :]) is taken before the final layer normalization. In vLLM, however, the returned hidden_states appear to have the final normalization already applied.
Is this discrepancy an unintended bug, or is there a configuration option in vLLM to control this output behavior (e.g., to return the pre-norm hidden states)?
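To make the discrepancy concrete, here is a minimal standalone sketch (not vLLM code; the tensors and the plain `rms_norm` helper are illustrative assumptions) showing that the pre-norm state `hidden_states + residual` and the post-norm state differ:

```python
import torch

# Minimal sketch (not vLLM code) of why pre-norm and post-norm last hidden
# states differ. In vLLM, decoder layers carry (hidden_states, residual)
# separately, and the final norm fuses the residual add with normalization.
def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm with unit weight, in the style of Qwen-family models.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

torch.manual_seed(0)
hidden = torch.randn(1, 4, 8)    # hypothetical last-layer output
residual = torch.randn(1, 4, 8)  # hypothetical residual stream

pre_norm = hidden + residual   # pre-norm state (what the report expects)
post_norm = rms_norm(pre_norm) # post-norm state (what vLLM returns)

print(torch.allclose(pre_norm, post_norm))  # False: the two outputs differ
```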
I don't have a minimal demo, but I modified the original code to test this. The forward method of Qwen3VLForConditionalGeneration contains the following code:
```python
hidden_states = self.language_model.model(
    input_ids=input_ids,
    positions=positions,
    intermediate_tensors=intermediate_tensors,
    inputs_embeds=inputs_embeds,
    # args for deepstack
    deepstack_input_embeds=deepstack_input_embeds,
)
```

The type of `self.language_model.model` is `Qwen3LLMModel`.
I introduced an environment variable, `LAST_HIDDEN_STATE_NOT_NORM`, checked just before the return of `Qwen3LLMModel`'s forward method:
```python
# note: requires `import os` at the top of the module
if os.environ.get("LAST_HIDDEN_STATE_NOT_NORM", "0") == "1":
    return hidden_states + residual
if not get_pp_group().is_last_rank:
    return IntermediateTensors(
        {"hidden_states": hidden_states, "residual": residual}
    )
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
```

When `LAST_HIDDEN_STATE_NOT_NORM=1` is set, the hidden-state output exactly matches Transformers' behavior.
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.