 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -68,6 +67,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
         prefix: str = "",
     ):
         super().__init__()
@@ -84,10 +84,11 @@ def __init__(
         self.head_dim = config.head_dim or (self.hidden_size //
                                             self.total_num_heads)
         self.q_size_per_rank = self.head_dim * self.num_heads
-
         self.num_kv_heads = self.total_kv_heads // tp_size
         self.kv_size_per_rank = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
+        self.use_rmsnorm = getattr(config, "use_rmsnorm", False)

         self.query_key_value = QKVParallelLinear(
             self.hidden_size,
@@ -99,28 +100,45 @@ def __init__(
             prefix=f"{prefix}.query_key_value",
         )

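+        # Optional per-head Q/K normalization: RMSNorm when the config asks for
+        # it, otherwise a LayerNorm with a fixed eps of 1e-6.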
+        if self.use_qk_norm:
+            self.query_layernorm = (RMSNorm(
+                self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
+                                    else nn.LayerNorm(self.head_dim, eps=1e-6))
+            self.key_layernorm = (RMSNorm(
+                self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
+                                  else nn.LayerNorm(self.head_dim, eps=1e-6))
+
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=config.use_bias,
             quant_config=quant_config,
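+            # With reduce_results=False the caller becomes responsible for the
+            # tensor-parallel all-reduce of the attention output.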
+            reduce_results=reduce_results,
             prefix=f"{prefix}.dense",
         )

-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scale,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              prefix=f"{prefix}.attn")
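+        # Partial-RoPE support: rotary_dim and partial_rotary_factor are read
+        # from the config and default to full-head rotary embeddings.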
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
+                                             1.0)
+
+        self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)

         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
+            rotary_dim=self.rotary_dim,
             max_position=config.max_position_embeddings,
             base=config.rope_theta,
             is_neox_style=True,
             rope_scaling=config.rope_scaling,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            prefix=f"{prefix}.attn",
         )

     def forward(
@@ -135,6 +153,14 @@ def forward(
         ],
                            dim=-1)

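+        # Normalize Q and K per head: reshape to (tokens, heads, head_dim),
+        # apply the norm, then flatten back before rotary embeddings.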
+        if self.use_qk_norm:
+            q = q.view(-1, self.num_heads, self.head_dim)
+            k = k.view(-1, self.num_kv_heads, self.head_dim)
+            q = self.query_layernorm(q)
+            k = self.key_layernorm(k)
+            q = q.view(-1, self.q_size_per_rank)
+            k = k.view(-1, self.kv_size_per_rank)
+
         q, k = self.rotary_emb(position_ids, q, k)

         context_layer = self.attn(q, k, v)
@@ -198,24 +224,72 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.quant_config = quant_config
         self.num_shared_experts = config.num_shared_experts
-        # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_experts,
-                                     bias=False,
-                                     quant_config=None)
-
-        self.experts = FusedMoE(num_experts=self.num_experts,
-                                top_k=self.top_k,
-                                hidden_size=self.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=self.norm_expert_prob,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
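+        # Routing options read from the config: grouped top-k expert selection,
+        # an optional score function with correction bias, and a scaling factor
+        # applied to the routed output.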
+        self.score_function = getattr(config, "score_function", None)
+        self.n_group = getattr(config, "n_group", None)
+        self.topk_group = getattr(config, "topk_group", None)
+        self.use_grouped_topk = (self.n_group is not None
+                                 and self.topk_group is not None)
+        self.routed_scaling_factor = getattr(config, "routed_scaling_factor",
+                                             1.0)
+
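+        # The gate can run in fp32 or bf16 independently of the model dtype; its
+        # logits are cast back to the hidden-state dtype in forward().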
+        router_dtype = getattr(config, "router_dtype", None)
+        if router_dtype is None:
+            self.router_dtype = None
+        elif router_dtype == "fp32":
+            self.router_dtype = torch.float32
+        else:
+            self.router_dtype = torch.bfloat16
+
+        self.gate = nn.Linear(
+            self.hidden_size,
+            self.num_experts,
+            bias=False,
+            dtype=self.router_dtype,
+        )
+
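+        # Optional per-expert bias, stored on the gate module and handed to
+        # FusedMoE as the score correction bias (used with sigmoid scoring).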
+        if getattr(config, "moe_router_enable_expert_bias", False):
+            self.gate.expert_bias = nn.Parameter(
+                torch.empty((config.num_experts, ), dtype=torch.float32))
+        else:
+            self.gate.expert_bias = None
+
+        self.correction_bias = (self.gate.expert_bias.data
+                                if self.gate.expert_bias is not None else None)
+
+        if self.score_function is not None:
+            assert (
+                self.score_function == "softmax"
+                and self.correction_bias is None
+            ) or (
+                self.score_function == "sigmoid"
+                and self.correction_bias is not None
+            ), ("score_function and correction_bias must be used in one of two "
+                "combinations: ('softmax', None) or ('sigmoid', not None)")
+        else:
+            # Default scoring function when the config does not specify one.
+            self.score_function = "softmax"
+
+        self.experts = FusedMoE(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            hidden_size=self.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=self.norm_expert_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            scoring_func=self.score_function,
+            e_score_correction_bias=self.gate.expert_bias,
+            num_expert_group=self.n_group,
+            topk_group=self.topk_group,
+            use_grouped_topk=self.use_grouped_topk,
+        )

         if self.num_shared_experts > 0:
-            intermediate_size = (config.moe_intermediate_size *
-                                 self.num_shared_experts)
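+            # Newer configs give the shared-expert width directly; older ones
+            # derive it from moe_intermediate_size * num_shared_experts.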
+            if hasattr(config, "moe_shared_expert_intermediate_size"):
+                intermediate_size = config.moe_shared_expert_intermediate_size
+            else:
+                intermediate_size = config.moe_intermediate_size
+                intermediate_size *= config.num_shared_experts
             self.shared_experts = BailingMLP(
                 intermediate_size=intermediate_size,
                 config=config,
@@ -228,14 +302,18 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
-        if self.num_shared_experts > 0:
+        if self.shared_experts:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states.to(self.router_dtype))
+        router_logits = router_logits.to(hidden_states.dtype)
+
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)

-        if self.num_shared_experts > 0:
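+        # Scale the routed-expert output before adding the shared-expert path.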
+        final_hidden_states *= self.routed_scaling_factor
+
+        if self.shared_experts:
             final_hidden_states = final_hidden_states + shared_output

         if self.tp_size > 1:
@@ -254,20 +332,30 @@ def __init__(
         prefix: str = "",
     ):
         super().__init__()
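+        # The layer index is the last component of the prefix
+        # (e.g. "model.layers.3").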
+        layer_idx = int(prefix.split('.')[-1])
+        self.config = config
         hidden_size = config.hidden_size
         intermediate_size = config.intermediate_size
+
         self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
         self.attention = BailingAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.attention")
+
         self.post_attention_layernorm = RMSNorm(hidden_size,
                                                 eps=config.rms_norm_eps)
-        self.mlp = BailingMoE(intermediate_size,
-                              config,
-                              quant_config,
-                              True,
-                              prefix=f"{prefix}.mlp")
+
+        # The first `first_k_dense_replace` layers use a dense MLP; later
+        # layers use the MoE block.
+        if layer_idx < config.first_k_dense_replace:
+            mlp_class = BailingMLP
+        else:
+            mlp_class = BailingMoE
+        self.mlp = mlp_class(intermediate_size,
+                             config,
+                             quant_config,
+                             True,
+                             prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -310,11 +398,17 @@ def __init__(
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed_dim = config.hidden_size
+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
+                                           False)

-        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+        if get_pp_group().is_first_rank or (self.tie_word_embeddings
                                             and get_pp_group().is_last_rank):
             self.word_embeddings = VocabParallelEmbedding(
-                self.vocab_size, self.embed_dim)
+                self.vocab_size,
+                self.embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.word_embeddings",
+            )
         else:
             self.word_embeddings = PPMissingLayer()

@@ -372,8 +466,11 @@ def forward(
                 "hidden_states": hidden_states,
                 "residual": residual
             })
-
-        hidden_states, _ = self.norm(hidden_states, residual)
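+        # Apply the final norm; fuse the residual add when a residual is
+        # present.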
+        else:
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -396,7 +493,8 @@ def load_weights(self, weights: Iterable[tuple[str,
         loaded_params: set[str] = set()
         expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
-            if self.config.norm_head and "lm_head.weight" in name:
+            if (hasattr(self.config, "norm_head") and self.config.norm_head
+                    and "lm_head.weight" in name):
                 loaded_weight = F.normalize(loaded_weight,
                                             dim=0,
                                             p=2,
@@ -430,13 +528,17 @@ def load_weights(self, weights: Iterable[tuple[str,

                 if is_pp_missing_parameter(name, self):
                     continue
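+                # Ignore expert weights with no matching parameter in this
+                # model.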
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param,
-                              loaded_weight,
-                              name,
-                              shard_id=shard_id,
-                              expert_id=expert_id)
+                weight_loader(
+                    param,
+                    loaded_weight,
+                    name,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                )
                 break
             else:
                 if name.endswith(".bias") and name not in params_dict:
@@ -473,19 +575,30 @@ def __init__(
     ) -> None:
         super().__init__()

-        config = vllm_config.model_config.hf_config
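+        # get_text_config() returns the nested text config for multimodal
+        # checkpoints and the config itself for text-only models.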
+        config = vllm_config.model_config.hf_config.get_text_config()
+        vllm_config.model_config.hf_config = config
         quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config

         self.config = config
+        self.lora_config = lora_config
         self.quant_config = quant_config
         self.max_position_embeddings = config.max_position_embeddings
         self.model = BailingMoeModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
+                                           False)
+
         if get_pp_group().is_last_rank:
-            self.lm_head = (self.word_embeddings if config.tie_word_embeddings
-                            else ParallelLMHead(config.vocab_size,
-                                                config.hidden_size,
-                                                quant_config=quant_config))
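+            # When embeddings are tied, reuse the input embedding weights as
+            # the output projection; otherwise build a separate ParallelLMHead.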
+            if self.tie_word_embeddings:
+                self.lm_head = self.model.word_embeddings
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.lm_head",
+                )
             self.logits_processor = LogitsProcessor(config.vocab_size)
         else:
             self.lm_head = PPMissingLayer()
@@ -520,10 +633,13 @@ def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
+            skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
+
+
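+# V2 checkpoints reuse the same implementation; differences between V1 and V2
+# are handled through the optional config attributes read above (qk-norm,
+# partial RoPE, grouped top-k routing, expert bias).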
+class BailingMoeV2ForCausalLM(BailingMoeForCausalLM):
+    pass