     GemmaRMSNorm as Qwen3NextRMSNorm)
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -254,12 +253,20 @@ def __init__(
         # projection of the input hidden states
         self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
         self.projection_size_ba = self.num_v_heads * 2
-        self.in_proj = MergedColumnParallelLinear(
+        self.in_proj_qkvz = ColumnParallelLinear(
             input_size=self.hidden_size,
-            output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
+            output_size=self.projection_size_qkvz,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.in_proj",
+            prefix=f"{prefix}.in_proj_qkvz",
+        )
+        # ba_proj doesn't support blockwise fp8 quantization.
+        self.in_proj_ba = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.projection_size_ba,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_ba",
         )

         query_key_settings = (self.key_dim, 0, False)
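
The hunk above breaks the fused `in_proj` (a `MergedColumnParallelLinear` computing the qkvz and ba projections in one GEMM) into two independent `ColumnParallelLinear` layers, so that `in_proj_ba` can be treated separately by the quantization config; per the added comment, the ba projection does not support blockwise fp8. The two layouts are numerically equivalent when the fused weight is split row-wise. A minimal sketch with plain `torch.nn.Linear` standing in for the tensor-parallel layers (sizes are illustrative, not taken from the model config):

```python
import torch
import torch.nn as nn

hidden_size, size_qkvz, size_ba = 16, 24, 8
x = torch.randn(4, hidden_size)

# Old layout: one fused projection, then split along the last dim.
fused = nn.Linear(hidden_size, size_qkvz + size_ba, bias=False)
qkvz_fused, ba_fused = torch.split(fused(x), [size_qkvz, size_ba], dim=-1)

# New layout: two independent projections whose weights are the
# corresponding row blocks of the fused weight matrix.
proj_qkvz = nn.Linear(hidden_size, size_qkvz, bias=False)
proj_ba = nn.Linear(hidden_size, size_ba, bias=False)
with torch.no_grad():
    proj_qkvz.weight.copy_(fused.weight[:size_qkvz])
    proj_ba.weight.copy_(fused.weight[size_qkvz:])

# Same outputs either way; only the module boundaries move.
assert torch.allclose(qkvz_fused, proj_qkvz(x))
assert torch.allclose(ba_fused, proj_ba(x))
```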
@@ -420,19 +427,14 @@ def _forward(
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
-
-        # 1. Set up dimensions for reshapes later
-        projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
         if spec_token_masks is not None:
             spec_token_masks = spec_token_masks[:num_actual_tokens]
-        projected_states_qkvz, projected_states_ba = torch.split(
-            projected_states,
-            [
-                self.projection_size_qkvz // self.tp_size,
-                self.projection_size_ba // self.tp_size
-            ],
-            dim=-1,
-        )
+
+        # 1. Set up dimensions for reshapes later
+        projected_states_qkvz, _ = self.in_proj_qkvz(
+            hidden_states[:num_actual_tokens])
+        projected_states_ba, _ = self.in_proj_ba(
+            hidden_states[:num_actual_tokens])
         query, key, value, z, b, a = self.fix_query_key_value_ordering(
             projected_states_qkvz, projected_states_ba)
         query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
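
In `_forward`, the manual `torch.split` disappears because each new layer already returns only its own per-rank output shard; there is no fused per-rank buffer left to carve up by `// self.tp_size`. A small arithmetic sketch of the per-rank widths involved (illustrative sizes, not the real model dimensions):

```python
# Per-rank output widths under tensor parallelism.
projection_size_qkvz, projection_size_ba, tp_size = 4096, 64, 4

# Old path: the fused layer's per-rank output concatenated both pieces,
# so it had to be split manually by the tp-divided sizes.
fused_rank_width = (projection_size_qkvz + projection_size_ba) // tp_size
split_sizes = [projection_size_qkvz // tp_size, projection_size_ba // tp_size]
assert sum(split_sizes) == fused_rank_width  # 1024 + 16 == 1040

# New path: in_proj_qkvz and in_proj_ba each return a tensor that is
# already projection_size_qkvz // tp_size and projection_size_ba // tp_size
# wide on every rank, so no torch.split is required.
```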
@@ -976,8 +978,6 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            ("in_proj", "in_proj_qkvz", 0),
-            ("in_proj", "in_proj_ba", 1),
         ]

         params_dict = dict(self.named_parameters())
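
The two `stacked_params_mapping` entries are dropped because checkpoint weights named `in_proj_qkvz.*` and `in_proj_ba.*` now match module parameters of the same name, so they no longer need to be renamed to `in_proj` and copied into shards 0 and 1 of a fused weight. A simplified sketch of how such a mapping loop typically routes weights (an assumed structure for illustration, not a verbatim copy of this file's `load_weights`):

```python
def route_weight(name, loaded_weight, params_dict, stacked_params_mapping):
    """Route one checkpoint tensor either into a shard of a fused parameter
    or straight into a same-named parameter (simplified sketch)."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Before this change, "in_proj_qkvz"/"in_proj_ba" were renamed to
        # "in_proj" and copied into shard 0 / shard 1 of the fused weight.
        param = params_dict[name.replace(weight_name, param_name)]
        param.weight_loader(param, loaded_weight, shard_id)
        return
    # After the change, in_proj_qkvz and in_proj_ba take this default path:
    # the checkpoint key already matches the parameter name.  (vLLM falls
    # back to a default loader when no weight_loader attribute is set.)
    param = params_dict[name]
    param.weight_loader(param, loaded_weight)
```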
@@ -1055,7 +1055,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "v_proj",
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
-        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
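
Finally, the `packed_modules_mapping` entry goes away for the same reason: that mapping advertises which checkpoint sub-modules are packed into a single vLLM module (used, for example, by LoRA to resolve target names), and `in_proj` is no longer a packed module. Reconstructed before/after of the class attribute (the `qkv_proj` entry is only partially visible in this hunk and is filled in as an assumption):

```python
# Before: in_proj packed two checkpoint sub-projections into one module.
packed_modules_mapping_before = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],  # assumed full entry
    "gate_up_proj": ["gate_proj", "up_proj"],
    "in_proj": ["in_proj_qkvz", "in_proj_ba"],
}

# After: in_proj_qkvz and in_proj_ba are standalone modules, so only the
# genuinely packed projections remain.
packed_modules_mapping_after = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],  # assumed full entry
    "gate_up_proj": ["gate_proj", "up_proj"],
}
```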