diff --git a/README.md b/README.md
index 77f9765..fdba758 100644
--- a/README.md
+++ b/README.md
@@ -132,11 +132,7 @@ vLLM supports offline batched inference or launching an OpenAI-Compatible API Se
 Since the Pull Request (PR) has not been submitted to the vLLM community at this stage, please prepare the environment by following the steps below:
 
 ```bash
-git clone -b v0.10.0 https://github.com/vllm-project/vllm.git
-cd vllm
-wget https://raw.githubusercontent.com/inclusionAI/Ling-V2/refs/heads/main/inference/vllm/bailing_moe_v2.patch
-git apply bailing_moe_v2.patch
-pip install -e .
+pip install vllm==0.11.0
 ```
 
 #### Offline Inference:
@@ -149,7 +145,7 @@ tokenizer = AutoTokenizer.from_pretrained("inclusionAI/Ling-mini-2.0")
 
 sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=16384)
 
-llm = LLM(model="inclusionAI/Ling-mini-2.0", dtype='bfloat16')
+llm = LLM(model="inclusionAI/Ling-mini-2.0", dtype='bfloat16', trust_remote_code=True)
 prompt = "Give me a short introduction to large language models."
 messages = [
     {"role": "system", "content": "You are Ling, an assistant created by inclusionAI"},
@@ -171,7 +167,7 @@ outputs = llm.generate([text], sampling_params)
 vllm serve inclusionAI/Ling-mini-2.0 \
     --tensor-parallel-size 2 \
     --pipeline-parallel-size 1 \
-    --use-v2-block-manager \
+    --trust-remote-code \
     --gpu-memory-utilization 0.90
 ```
 
diff --git a/inference/vllm/bailing_moe_v2.patch b/inference/vllm/bailing_moe_v2.patch
deleted file mode 100644
index 7044dd9..0000000
--- a/inference/vllm/bailing_moe_v2.patch
+++ /dev/null
@@ -1,422 +0,0 @@
-diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
-index 853c13b13..2314f350e 100644
---- a/vllm/model_executor/models/bailing_moe.py
-+++ b/vllm/model_executor/models/bailing_moe.py
-@@ -32,6 +32,7 @@ from torch import nn
- from transformers.configuration_utils import PretrainedConfig
- 
- from vllm.attention import Attention
-+from vllm.compilation.decorators import support_torch_compile
- from vllm.config import CacheConfig, VllmConfig
- from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
-                               get_tensor_model_parallel_world_size,
-@@ -66,6 +67,7 @@ class BailingAttention(nn.Module):
-         config: PretrainedConfig,
-         cache_config: Optional[CacheConfig] = None,
-         quant_config: Optional[QuantizationConfig] = None,
-+        reduce_results: bool = True,
-         prefix: str = "",
-     ):
-         super().__init__()
-@@ -82,10 +84,11 @@ class BailingAttention(nn.Module):
-         self.head_dim = config.head_dim or (self.hidden_size //
-                                             self.total_num_heads)
-         self.q_size_per_rank = self.head_dim * self.num_heads
--
-         self.num_kv_heads = self.total_kv_heads // tp_size
-         self.kv_size_per_rank = self.num_kv_heads * self.head_dim
-         self.scale = self.head_dim**-0.5
-+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
-+        self.use_rmsnorm = getattr(config, "use_rmsnorm", False)
- 
-         self.query_key_value = QKVParallelLinear(
-             self.hidden_size,
-@@ -97,30 +100,46 @@ class BailingAttention(nn.Module):
-             prefix=f"{prefix}.query_key_value",
-         )
- 
-+        if self.use_qk_norm:
-+            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm \
-+                else nn.LayerNorm(self.head_dim, eps=1e-6)
-+            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm \
-+                else nn.LayerNorm(self.head_dim, eps=1e-6)
-+
-         self.dense = RowParallelLinear(
-             self.total_num_heads * self.head_dim,
-             self.hidden_size,
-             bias=config.use_bias,
-             quant_config=quant_config,
-+            reduce_results=reduce_results,
-             prefix=f"{prefix}.dense",
-         )
- 
--        self.attn = Attention(self.num_heads,
--                              self.head_dim,
--                              self.scale,
--                              num_kv_heads=self.num_kv_heads,
--                              cache_config=cache_config,
--                              prefix=f"{prefix}.attn")
-+        if hasattr(config, "partial_rotary_factor"):
-+            self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
-+        elif hasattr(config, "rotary_dim"):
-+            self.rotary_dim = config.rotary_dim
-+        else:
-+            self.rotary_dim = self.head_dim
- 
-         self.rotary_emb = get_rope(
-             self.head_dim,
--            rotary_dim=self.head_dim,
-+            rotary_dim=self.rotary_dim,
-             max_position=config.max_position_embeddings,
-             base=config.rope_theta,
-             is_neox_style=True,
-             rope_scaling=config.rope_scaling,
-         )
- 
-+        self.attn = Attention(
-+            self.num_heads,
-+            self.head_dim,
-+            self.scale,
-+            num_kv_heads=self.num_kv_heads,
-+            cache_config=cache_config,
-+            prefix=f"{prefix}.attn",
-+        )
-+
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-@@ -133,6 +152,14 @@ class BailingAttention(nn.Module):
-         ],
-                             dim=-1)
- 
-+        if self.use_qk_norm:
-+            q = q.view(-1, self.num_heads, self.head_dim)
-+            k = k.view(-1, self.num_kv_heads, self.head_dim)
-+            q = self.query_layernorm(q)
-+            k = self.key_layernorm(k)
-+            q = q.view(-1, self.q_size_per_rank)
-+            k = k.view(-1, self.kv_size_per_rank)
-+
-         q, k = self.rotary_emb(position_ids, q, k)
- 
-         context_layer = self.attn(q, k, v)
-@@ -196,44 +223,95 @@ class BailingMoE(nn.Module):
-         self.hidden_size = config.hidden_size
-         self.quant_config = quant_config
-         self.num_shared_experts = config.num_shared_experts
--        # Gate always runs at half / full precision for now.
--        self.gate = ReplicatedLinear(self.hidden_size,
--                                     self.num_experts,
--                                     bias=False,
--                                     quant_config=None)
--
--        self.experts = FusedMoE(num_experts=self.num_experts,
--                                top_k=self.top_k,
--                                hidden_size=self.hidden_size,
--                                intermediate_size=config.moe_intermediate_size,
--                                reduce_results=False,
--                                renormalize=self.norm_expert_prob,
--                                quant_config=quant_config,
--                                prefix=f"{prefix}.experts")
-+        self.score_function = getattr(config, "score_function", None)
-+        self.n_group = getattr(config, "n_group", None)
-+        self.topk_group = getattr(config, "topk_group", None)
-+        self.use_grouped_topk = (self.n_group is not None
-+                                 and self.topk_group is not None)
-+        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
-+
-+        router_dtype = getattr(config, "router_dtype", None)
-+        if router_dtype is None:
-+            self.router_dtype = None
-+        elif router_dtype == "fp32":
-+            self.router_dtype = torch.float32
-+        else:
-+            self.router_dtype = torch.bfloat16
-+
-+        self.gate = ReplicatedLinear(
-+            self.hidden_size,
-+            self.num_experts,
-+            bias=False,
-+            quant_config=None,
-+            params_dtype=self.router_dtype,
-+        )
-+
-+        if getattr(config, "moe_router_enable_expert_bias", False):
-+            self.gate.expert_bias = nn.Parameter(torch.empty((config.num_experts,), dtype=torch.float32))
-+        else:
-+            self.gate.expert_bias = None
-+
-+        self.correction_bias = (
-+            self.gate.expert_bias.data if self.gate.expert_bias is not None else None
-+        )
-+
-+        if self.score_function is not None:
-+            assert (
-+                self.score_function == "softmax" and self.correction_bias is None
-+            ) or (
-+                self.score_function == "sigmoid" and self.correction_bias is not None
-+            ), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"
-+        else:
-+            # default value for scoring_func
-+            self.score_function = "softmax"
-+
-+        self.experts = FusedMoE(
-+            num_experts=self.num_experts,
-+            top_k=self.top_k,
-+            hidden_size=self.hidden_size,
-+            intermediate_size=config.moe_intermediate_size,
-+            reduce_results=False,
-+            renormalize=self.norm_expert_prob,
-+            quant_config=quant_config,
-+            prefix=f"{prefix}.experts",
-+            scoring_func=self.score_function,
-+            e_score_correction_bias=self.gate.expert_bias,
-+            num_expert_group=self.n_group,
-+            topk_group=self.topk_group,
-+            use_grouped_topk=self.use_grouped_topk,
-+        )
- 
-         if self.num_shared_experts > 0:
--            intermediate_size = (config.moe_intermediate_size *
--                                 self.num_shared_experts)
-+            if hasattr(config, "moe_shared_expert_intermediate_size"):
-+                intermediate_size = config.moe_shared_expert_intermediate_size
-+            else:
-+                intermediate_size = config.moe_intermediate_size
-+            intermediate_size *= config.num_shared_experts
-             self.shared_experts = BailingMLP(
-                 intermediate_size=intermediate_size,
-                 config=config,
-                 quant_config=quant_config,
-                 reduce_results=False,
--                prefix=f"{prefix}.shared_experts")
-+                prefix=f"{prefix}.shared_experts"
-+            )
-         else:
-             self.shared_experts = None
- 
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         num_tokens, hidden_size = hidden_states.shape
-         hidden_states = hidden_states.view(-1, hidden_size)
--        if self.num_shared_experts > 0:
-+        if self.shared_experts:
-             shared_output = self.shared_experts(hidden_states)
-         # router_logits: (num_tokens, n_experts)
--        router_logits, _ = self.gate(hidden_states)
-+        router_logits, _ = self.gate(hidden_states.to(self.router_dtype))
-+        router_logits = router_logits.to(hidden_states.dtype)
-+
-         final_hidden_states = self.experts(hidden_states=hidden_states,
-                                            router_logits=router_logits)
- 
--        if self.num_shared_experts > 0:
-+        final_hidden_states *= self.routed_scaling_factor
-+
-+        if self.shared_experts:
-             final_hidden_states = final_hidden_states + shared_output
- 
-         if self.tp_size > 1:
-@@ -252,20 +330,28 @@ class BailingMoeBlock(nn.Module):
-         prefix: str = "",
-     ):
-         super().__init__()
-+        layer_idx = int(prefix.split('.')[-1])
-+        self.config = config
-         hidden_size = config.hidden_size
-         intermediate_size = config.intermediate_size
-+
-         self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
--        self.attention = BailingAttention(config,
--                                          cache_config,
--                                          quant_config,
--                                          prefix=f"{prefix}.attention")
--        self.post_attention_layernorm = RMSNorm(hidden_size,
--                                                eps=config.rms_norm_eps)
--        self.mlp = BailingMoE(intermediate_size,
--                              config,
--                              quant_config,
--                              True,
--                              prefix=f"{prefix}.mlp")
-+
-+        self.attention = BailingAttention(
-+            config,
-+            cache_config,
-+            quant_config,
-+            prefix=f"{prefix}.attention",
-+        )
-+
-+        self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
-+
-+        # Choose MLP class based on the number of experts and layer index
-+        if layer_idx < config.first_k_dense_replace:
-+            mlp_class = BailingMLP
-+        else:
-+            mlp_class = BailingMoE
-+        self.mlp = mlp_class(intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp")
- 
-     def forward(
-         self,
-@@ -291,6 +377,7 @@ class BailingMoeBlock(nn.Module):
-         return hidden_states, residual
- 
- 
-+@support_torch_compile
- class BailingMoeModel(nn.Module):
- 
-     def __init__(
-@@ -307,11 +394,16 @@ class BailingMoeModel(nn.Module):
-         self.config = config
-         self.vocab_size = config.vocab_size
-         self.embed_dim = config.hidden_size
-+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
- 
--        if get_pp_group().is_first_rank or (config.tie_word_embeddings
--                                            and get_pp_group().is_last_rank):
-+        if get_pp_group().is_first_rank or (self.tie_word_embeddings and
-+                                            get_pp_group().is_last_rank):
-             self.word_embeddings = VocabParallelEmbedding(
--                self.vocab_size, self.embed_dim)
-+                self.vocab_size,
-+                self.embed_dim,
-+                quant_config=quant_config,
-+                prefix=f"{prefix}.word_embeddings",
-+            )
-         else:
-             self.word_embeddings = PPMissingLayer()
- 
-@@ -325,11 +417,14 @@ class BailingMoeModel(nn.Module):
-                 quant_config=quant_config,
-                 prefix=prefix,
-             ),
--            prefix=f"{prefix}.layers")
-+            prefix=f"{prefix}.layers"
-+        )
- 
-         self.make_empty_intermediate_tensors = (
-             make_empty_intermediate_tensors_factory(
--                ["hidden_states", "residual"], config.hidden_size))
-+                ["hidden_states", "residual"], config.hidden_size
-+            )
-+        )
- 
-         if get_pp_group().is_last_rank:
-             self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
-@@ -370,8 +465,11 @@ class BailingMoeModel(nn.Module):
-                 "hidden_states": hidden_states,
-                 "residual": residual
-             })
--
--        hidden_states, _ = self.norm(hidden_states, residual)
-+        else:
-+            if residual is None:
-+                hidden_states = self.norm(hidden_states)
-+            else:
-+                hidden_states, _ = self.norm(hidden_states, residual)
-         return hidden_states
- 
-     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-@@ -394,7 +492,11 @@ class BailingMoeModel(nn.Module):
-         loaded_params: set[str] = set()
-         expert_params_mapping = self.get_expert_mapping()
-         for name, loaded_weight in weights:
--            if self.config.norm_head and "lm_head.weight" in name:
-+            if (
-+                hasattr(self.config, "norm_head")
-+                and self.config.norm_head
-+                and "lm_head.weight" in name
-+            ):
-                 loaded_weight = F.normalize(loaded_weight,
-                                             dim=0,
-                                             p=2,
-@@ -428,13 +530,17 @@ class BailingMoeModel(nn.Module):
- 
-                 if is_pp_missing_parameter(name, self):
-                     continue
-+                if name not in params_dict:
-+                    continue
-                 param = params_dict[name]
-                 weight_loader = param.weight_loader
--                weight_loader(param,
--                              loaded_weight,
--                              name,
--                              shard_id=shard_id,
--                              expert_id=expert_id)
-+                weight_loader(
-+                    param,
-+                    loaded_weight,
-+                    name,
-+                    shard_id=shard_id,
-+                    expert_id=expert_id,
-+                )
-                 break
-             else:
-                 if name.endswith(".bias") and name not in params_dict:
-@@ -472,24 +578,37 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
-         super().__init__()
- 
-         config = vllm_config.model_config.hf_config
-+        if hasattr(config, "llm_config"):
-+            config = config.llm_config
-+            vllm_config.model_config.hf_config = config
-         quant_config = vllm_config.quant_config
-+        lora_config = vllm_config.lora_config
- 
-         self.config = config
-+        self.lora_config = lora_config
-         self.quant_config = quant_config
-         self.max_position_embeddings = config.max_position_embeddings
-         self.model = BailingMoeModel(vllm_config=vllm_config,
-                                      prefix=maybe_prefix(prefix, "model"))
-+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
-+
-         if get_pp_group().is_last_rank:
--            self.lm_head = (self.word_embeddings if config.tie_word_embeddings
--                            else ParallelLMHead(config.vocab_size,
--                                                config.hidden_size,
--                                                quant_config=quant_config))
-+            if self.tie_word_embeddings:
-+                self.lm_head = self.model.word_embeddings
-+            else:
-+                self.lm_head = ParallelLMHead(
-+                    config.vocab_size,
-+                    config.hidden_size,
-+                    quant_config=quant_config,
-+                    prefix=f"{prefix}.lm_head",
-+                )
-             self.logits_processor = LogitsProcessor(config.vocab_size)
-         else:
-             self.lm_head = PPMissingLayer()
- 
-         self.make_empty_intermediate_tensors = (
--            self.model.make_empty_intermediate_tensors)
-+            self.model.make_empty_intermediate_tensors
-+        )
- 
-     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-         return self.model.get_input_embeddings(input_ids)
-@@ -519,9 +638,12 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
-         loader = AutoWeightsLoader(
-             self,
-             skip_prefixes=(["lm_head."]
--                           if self.config.tie_word_embeddings else None),
-+                           if self.tie_word_embeddings else None),
-         )
-         return loader.load_weights(weights)
- 
-     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-         return self.model.get_expert_mapping()
-+
-+class BailingMoeV2ForCausalLM(BailingMoeForCausalLM):
-+    pass
-diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
-index 2aaac7798..fcf295e50 100644
---- a/vllm/model_executor/models/registry.py
-+++ b/vllm/model_executor/models/registry.py
-@@ -43,6 +43,7 @@ _TEXT_GENERATION_MODELS = {
-     # baichuan-13b, lower case 'c' in the class name
-     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
-     "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
-+    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
-     "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
-     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
-     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
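The `vllm serve` command in the README hunk above exposes an OpenAI-compatible endpoint. Below is a minimal client-side sketch for querying it, assuming the server is running locally on vLLM's default port 8000 with the standard `/v1` route and that the `openai` Python package is installed; adjust `base_url` if your deployment differs.

```python
from openai import OpenAI

# Point the OpenAI client at the local vLLM server.
# vLLM accepts any placeholder API key unless one was configured at launch.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Send a chat request to the served model; sampling values mirror the
# offline-inference example in the README.
response = client.chat.completions.create(
    model="inclusionAI/Ling-mini-2.0",
    messages=[
        {"role": "system", "content": "You are Ling, an assistant created by inclusionAI"},
        {"role": "user", "content": "Give me a short introduction to large language models."},
    ],
    temperature=0.7,
    top_p=0.8,
    max_tokens=16384,
)

print(response.choices[0].message.content)
```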