
Commit 72c99f2

ant-yy and Isotr0py authored
[Model]: support Ling2.0 (vllm-project#24627)
Signed-off-by: vito.yy <[email protected]> Co-authored-by: Isotr0py <[email protected]>
1 parent bf214ca commit 72c99f2

File tree: 4 files changed (+170, −50 lines)


docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -328,6 +328,7 @@ th {
 | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
 | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
 | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ |
 | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
```

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -180,6 +180,8 @@ def check_available_online(
                                              trust_remote_code=True),
     "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
                                              trust_remote_code=True),
+    "BailingMoeV2ForCausalLM": _HfExamplesInfo("inclusionAI/Ling-mini-2.0",
+                                               trust_remote_code=True),
     "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
                                         min_transformers_version="4.55.3",
                                         extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}),  # noqa: E501
```

vllm/model_executor/models/bailing_moe.py

Lines changed: 166 additions & 50 deletions
```diff
@@ -43,7 +43,6 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -68,6 +67,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
         prefix: str = "",
     ):
         super().__init__()
@@ -84,10 +84,11 @@ def __init__(
         self.head_dim = config.head_dim or (self.hidden_size //
                                             self.total_num_heads)
         self.q_size_per_rank = self.head_dim * self.num_heads
-
         self.num_kv_heads = self.total_kv_heads // tp_size
         self.kv_size_per_rank = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
+        self.use_rmsnorm = getattr(config, "use_rmsnorm", False)

         self.query_key_value = QKVParallelLinear(
             self.hidden_size,
@@ -99,28 +100,45 @@ def __init__(
             prefix=f"{prefix}.query_key_value",
         )

+        if self.use_qk_norm:
+            self.query_layernorm = (RMSNorm(
+                self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
+                                    else nn.LayerNorm(self.head_dim, eps=1e-6))
+            self.key_layernorm = (RMSNorm(
+                self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
+                                  else nn.LayerNorm(self.head_dim, eps=1e-6))
+
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=config.use_bias,
             quant_config=quant_config,
+            reduce_results=reduce_results,
             prefix=f"{prefix}.dense",
         )

-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scale,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              prefix=f"{prefix}.attn")
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
+                                             1.0)
+
+        self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)

         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
+            rotary_dim=self.rotary_dim,
             max_position=config.max_position_embeddings,
             base=config.rope_theta,
             is_neox_style=True,
             rope_scaling=config.rope_scaling,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            prefix=f"{prefix}.attn",
         )

     def forward(
```
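Ling 2.0 rotates only part of each head (`partial_rotary_factor` / `rotary_dim`), leaving the remaining channels position-agnostic; vLLM's `get_rope` handles this once the factor is forwarded. A self-contained sketch of the idea outside vLLM (the function name, shapes, and random `cos`/`sin` here are illustrative, not vLLM internals):

```python
import torch

def apply_partial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       partial_rotary_factor: float = 0.5) -> torch.Tensor:
    """Rotate only the leading `rotary_dim` channels of each head."""
    head_dim = x.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    # Neox-style rotate-half applied to the rotary slice only.
    x1, x2 = x_rot.chunk(2, dim=-1)
    x_rot = x_rot * cos + torch.cat((-x2, x1), dim=-1) * sin
    # The trailing channels pass through without positional rotation.
    return torch.cat((x_rot, x_pass), dim=-1)

# cos/sin must broadcast against the rotary slice (random here for shape demo;
# real values come from the position frequencies).
q = torch.randn(4, 8, 64)                        # (tokens, heads, head_dim)
cos, sin = torch.randn(4, 1, 32), torch.randn(4, 1, 32)
q = apply_partial_rope(q, cos, sin)
```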
```diff
@@ -135,6 +153,14 @@ def forward(
             ],
             dim=-1)

+        if self.use_qk_norm:
+            q = q.view(-1, self.num_heads, self.head_dim)
+            k = k.view(-1, self.num_kv_heads, self.head_dim)
+            q = self.query_layernorm(q)
+            k = self.key_layernorm(k)
+            q = q.view(-1, self.q_size_per_rank)
+            k = k.view(-1, self.kv_size_per_rank)
+
         q, k = self.rotary_emb(position_ids, q, k)

         context_layer = self.attn(q, k, v)
```
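The `use_qk_norm` branch normalizes queries and keys head-by-head before RoPE (RMSNorm when `use_rmsnorm` is set, LayerNorm otherwise), which keeps attention logits bounded in deep stacks. The reshape-normalize-flatten pattern in isolation (toy shapes, plain `nn.LayerNorm` stand-in):

```python
import torch
from torch import nn

num_tokens, num_heads, head_dim = 4, 8, 64
q = torch.randn(num_tokens, num_heads * head_dim)  # flattened q on this rank
query_layernorm = nn.LayerNorm(head_dim, eps=1e-6)

# Normalize each head's query vector independently, then flatten back so
# RoPE and the attention kernel see the layout they expect.
q = query_layernorm(q.view(-1, num_heads, head_dim)).view(num_tokens, -1)
```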
```diff
@@ -198,24 +224,72 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.quant_config = quant_config
         self.num_shared_experts = config.num_shared_experts
-        # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_experts,
-                                     bias=False,
-                                     quant_config=None)
-
-        self.experts = FusedMoE(num_experts=self.num_experts,
-                                top_k=self.top_k,
-                                hidden_size=self.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=self.norm_expert_prob,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+        self.score_function = getattr(config, "score_function", None)
+        self.n_group = getattr(config, "n_group", None)
+        self.topk_group = getattr(config, "topk_group", None)
+        self.use_grouped_topk = (self.n_group is not None
+                                 and self.topk_group is not None)
+        self.routed_scaling_factor = getattr(config, "routed_scaling_factor",
+                                             1.0)
+
+        router_dtype = getattr(config, "router_dtype", None)
+        if router_dtype is None:
+            self.router_dtype = None
+        elif router_dtype == "fp32":
+            self.router_dtype = torch.float32
+        else:
+            self.router_dtype = torch.bfloat16
+
+        self.gate = nn.Linear(
+            self.hidden_size,
+            self.num_experts,
+            bias=False,
+            dtype=self.router_dtype,
+        )
+
+        if getattr(config, "moe_router_enable_expert_bias", False):
+            self.gate.expert_bias = nn.Parameter(
+                torch.empty((config.num_experts, ), dtype=torch.float32))
+        else:
+            self.gate.expert_bias = None
+
+        self.correction_bias = (self.gate.expert_bias.data
+                                if self.gate.expert_bias is not None else None)
+
+        if self.score_function is not None:
+            assert (
+                self.score_function == "softmax"
+                and self.correction_bias is None
+            ) or (
+                self.score_function == "sigmoid"
+                and self.correction_bias is not None
+            ), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"  # noqa: E501
+        else:
+            # default value for scoring_func
+            self.score_function = "softmax"
+
+        self.experts = FusedMoE(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            hidden_size=self.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=self.norm_expert_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            scoring_func=self.score_function,
+            e_score_correction_bias=self.gate.expert_bias,
+            num_expert_group=self.n_group,
+            topk_group=self.topk_group,
+            use_grouped_topk=self.use_grouped_topk,
+        )

         if self.num_shared_experts > 0:
-            intermediate_size = (config.moe_intermediate_size *
-                                 self.num_shared_experts)
+            if hasattr(config, "moe_shared_expert_intermediate_size"):
+                intermediate_size = config.moe_shared_expert_intermediate_size
+            else:
+                intermediate_size = config.moe_intermediate_size
+            intermediate_size *= config.num_shared_experts
             self.shared_experts = BailingMLP(
                 intermediate_size=intermediate_size,
                 config=config,
```
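The rewritten gate supports Ling 2.0's sigmoid scoring with an additive per-expert correction bias and, when `n_group`/`topk_group` are set, DeepSeek-style grouped top-k routing: experts are partitioned into groups, the best-scoring groups are kept, and top-k is taken only inside them. A rough, unfused reference for that selection (function and variable names are mine; vLLM does this inside `FusedMoE`'s kernels, and details such as the top-2 group score are borrowed from the DeepSeek-V3 recipe as an assumption):

```python
import torch

def grouped_topk(router_logits: torch.Tensor, correction_bias: torch.Tensor,
                 top_k: int, n_group: int, topk_group: int):
    """Sketch of sigmoid scoring + grouped top-k expert selection."""
    scores = router_logits.sigmoid() + correction_bias   # (tokens, experts)
    num_tokens, num_experts = scores.shape
    # Score each group by the sum of its two best experts.
    group_scores = (scores.view(num_tokens, n_group, -1)
                    .topk(2, dim=-1).values.sum(dim=-1))  # (tokens, n_group)
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    # Mask out experts that belong to non-selected groups.
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = (group_mask.unsqueeze(-1)
                   .expand(num_tokens, n_group, num_experts // n_group)
                   .reshape(num_tokens, num_experts))
    masked = scores.masked_fill(expert_mask == 0, float("-inf"))
    topk_weights, topk_ids = masked.topk(top_k, dim=-1)
    return topk_weights, topk_ids

# 16 experts in 4 groups, keep 2 groups, route each token to 4 experts.
w, ids = grouped_topk(torch.randn(3, 16), torch.zeros(16),
                      top_k=4, n_group=4, topk_group=2)
```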
```diff
@@ -228,14 +302,18 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
-        if self.num_shared_experts > 0:
+        if self.shared_experts:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states.to(self.router_dtype))
+        router_logits = router_logits.to(hidden_states.dtype)
+
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)

-        if self.num_shared_experts > 0:
+        final_hidden_states *= self.routed_scaling_factor
+
+        if self.shared_experts:
             final_hidden_states = final_hidden_states + shared_output

         if self.tp_size > 1:
```
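The forward pass now runs the gate in `router_dtype` (fp32 keeps the scoring numerically stable at low activation precision) and rescales the routed output by `routed_scaling_factor` before the shared-expert output is added. The dtype round-trip and combine order, reduced to a toy (the zero expert outputs are stand-ins for vLLM's fused kernels):

```python
import torch
from torch import nn

hidden = torch.randn(4, 32, dtype=torch.bfloat16)
gate = nn.Linear(32, 8, bias=False, dtype=torch.float32)

# Route in fp32, then return logits to the activation dtype.
router_logits = gate(hidden.to(torch.float32)).to(hidden.dtype)

routed_out = torch.zeros_like(hidden)    # stand-in for FusedMoE output
shared_out = torch.zeros_like(hidden)    # stand-in for shared experts
routed_scaling_factor = 1.0

# Scale the routed path first, then fold in the always-on shared experts.
out = routed_out * routed_scaling_factor + shared_out
```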
```diff
@@ -254,20 +332,30 @@ def __init__(
         prefix: str = "",
     ):
         super().__init__()
+        layer_idx = int(prefix.split('.')[-1])
+        self.config = config
         hidden_size = config.hidden_size
         intermediate_size = config.intermediate_size
+
         self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
         self.attention = BailingAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.attention")
+
         self.post_attention_layernorm = RMSNorm(hidden_size,
                                                 eps=config.rms_norm_eps)
-        self.mlp = BailingMoE(intermediate_size,
-                              config,
-                              quant_config,
-                              True,
-                              prefix=f"{prefix}.mlp")
+
+        # Choose MLP class based on the number of experts and layer index
+        if layer_idx < config.first_k_dense_replace:
+            mlp_class = BailingMLP
+        else:
+            mlp_class = BailingMoE
+        self.mlp = mlp_class(intermediate_size,
+                             config,
+                             quant_config,
+                             True,
+                             prefix=f"{prefix}.mlp")

     def forward(
         self,
```
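As in DeepSeek-style MoE stacks, the first `first_k_dense_replace` layers keep a dense MLP and only later layers are MoE; the layer index is recovered by parsing the module prefix (`model.layers.<idx>`). The selection logic as a tiny self-check (the prefixes and threshold below are illustrative):

```python
def is_dense_layer(prefix: str, first_k_dense_replace: int) -> bool:
    """Layers below the threshold use a dense MLP instead of MoE."""
    layer_idx = int(prefix.split(".")[-1])
    return layer_idx < first_k_dense_replace

assert is_dense_layer("model.layers.0", first_k_dense_replace=1)
assert not is_dense_layer("model.layers.3", first_k_dense_replace=1)
```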
```diff
@@ -310,11 +398,17 @@ def __init__(
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed_dim = config.hidden_size
+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
+                                           False)

-        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+        if get_pp_group().is_first_rank or (self.tie_word_embeddings
                                             and get_pp_group().is_last_rank):
             self.word_embeddings = VocabParallelEmbedding(
-                self.vocab_size, self.embed_dim)
+                self.vocab_size,
+                self.embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.word_embeddings",
+            )
         else:
             self.word_embeddings = PPMissingLayer()

@@ -372,8 +466,11 @@ def forward(
                 "hidden_states": hidden_states,
                 "residual": residual
             })
-
-        hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
```
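Only the last pipeline rank applies the final norm; earlier ranks return `IntermediateTensors`, and the new `else` branch also covers the case where no residual is pending. The two norm paths, mimicked with a plain `torch.nn.RMSNorm` (a stand-in; vLLM's `RMSNorm` fuses the residual add, which is why the real code unpacks a tuple):

```python
import torch
from torch import nn

norm = nn.RMSNorm(32, eps=1e-6)   # plain norm, no fused residual add
hidden_states = torch.randn(4, 32)
residual = torch.randn(4, 32)     # None if already folded into hidden_states

# Last pipeline rank only: fold in the pending residual before normalizing.
if residual is None:
    hidden_states = norm(hidden_states)
else:
    hidden_states = norm(hidden_states + residual)
```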
```diff
@@ -396,7 +493,8 @@ def load_weights(self, weights: Iterable[tuple[str,
         loaded_params: set[str] = set()
         expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
-            if self.config.norm_head and "lm_head.weight" in name:
+            if (hasattr(self.config, "norm_head") and self.config.norm_head
+                    and "lm_head.weight" in name):
                 loaded_weight = F.normalize(loaded_weight,
                                             dim=0,
                                             p=2,
```
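`norm_head` is a Bailing checkpoint quirk: the `lm_head` weight is L2-normalized column-wise once at load time rather than on every forward, and the added `hasattr` guard lets V2 configs omit the flag entirely. What the normalization does, on a toy weight:

```python
import torch
import torch.nn.functional as F

lm_head_weight = torch.randn(128, 64)   # (vocab, hidden); toy shape

# L2-normalize along dim=0 so every hidden-dim column has unit norm.
normed = F.normalize(lm_head_weight, dim=0, p=2)
assert torch.allclose(normed.norm(dim=0), torch.ones(64), atol=1e-5)
```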
```diff
@@ -430,13 +528,17 @@ def load_weights(self, weights: Iterable[tuple[str,

                 if is_pp_missing_parameter(name, self):
                     continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param,
-                              loaded_weight,
-                              name,
-                              shard_id=shard_id,
-                              expert_id=expert_id)
+                weight_loader(
+                    param,
+                    loaded_weight,
+                    name,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                )
                 break
             else:
                 if name.endswith(".bias") and name not in params_dict:

@@ -473,19 +575,30 @@ def __init__(
     ) -> None:
         super().__init__()

-        config = vllm_config.model_config.hf_config
+        config = vllm_config.model_config.hf_config.get_text_config()
+        vllm_config.model_config.hf_config = config
         quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config

         self.config = config
+        self.lora_config = lora_config
         self.quant_config = quant_config
         self.max_position_embeddings = config.max_position_embeddings
         self.model = BailingMoeModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
+                                           False)
+
         if get_pp_group().is_last_rank:
-            self.lm_head = (self.word_embeddings if config.tie_word_embeddings
-                            else ParallelLMHead(config.vocab_size,
-                                                config.hidden_size,
-                                                quant_config=quant_config))
+            if self.tie_word_embeddings:
+                self.lm_head = self.model.word_embeddings
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.lm_head",
+                )
             self.logits_processor = LogitsProcessor(config.vocab_size)
         else:
             self.lm_head = PPMissingLayer()
```
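With `tie_word_embeddings`, the output head reuses the input embedding matrix, now fetched explicitly from `self.model.word_embeddings` instead of relying on attribute lookup on `self`. The underlying pattern with plain modules (a generic sketch, not vLLM's parallel classes):

```python
import torch
from torch import nn

vocab_size, hidden_size = 1000, 64
embed = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tie: the logits projection shares storage with the embedding table.
lm_head.weight = embed.weight
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
```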
```diff
@@ -520,10 +633,13 @@ def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
+            skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
+
+
+class BailingMoeV2ForCausalLM(BailingMoeForCausalLM):
+    pass
```

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -52,6 +52,7 @@
     # baichuan-13b, lower case 'c' in the class name
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
     "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
+    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
     "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
```
