
Commit 2887050

[Feature] add Qwen3MoE models for KTransformers-FT (#1602)
Authored by poryfly and unknown

* add qwen3 attn
* fix KQwen3MoeSparseMoeBlock
* fix bug adapter for llamafactory

Co-authored-by: unknown <[email protected]>

1 parent ab8ad0a, commit 2887050

File tree: 5 files changed, +726 additions, -9 deletions


KT-SFT/ktransformers/models/modeling_qwen3_moe.py

Lines changed: 8 additions & 8 deletions
@@ -206,14 +206,14 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
-        # if self.config._attn_implementation != "eager":
-        #     if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-        #         logger.warning_once(
-        #             "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-        #             'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-        #         )
-        #     else:
-        #         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,
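
This hunk re-enables the standard attention-backend dispatch in the Qwen3 MoE attention layer: unless config._attn_implementation is "eager", the forward pass looks the kernel up in ALL_ATTENTION_FUNCTIONS, falling back to eager only when output_attentions=True is requested under SDPA. As a rough illustration of how that flag is normally chosen (not part of this commit; the checkpoint id is a placeholder), a Hugging Face model can be loaded with an explicit attn_implementation:

    # Illustrative sketch only, not code from this commit.
    # "Qwen/Qwen3-30B-A3B" is a placeholder checkpoint id.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-30B-A3B",
        attn_implementation="sdpa",   # dispatched via ALL_ATTENTION_FUNCTIONS["sdpa"]
        torch_dtype="auto",
    )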

KT-SFT/ktransformers/operators/attention.py

Lines changed: 139 additions & 1 deletion
@@ -13,14 +13,15 @@
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
-from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
+from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention, Qwen3MoeRotaryEmbedding
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_loader import GGUFLoader
 from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 
 try:
@@ -943,3 +944,140 @@ def forward(
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output).to(input_dtype)
         return attn_output, attn_weights
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
+    def __init__(self,
+                 key: str,
+                 gguf_loader: GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 prefill_device: str = "cuda",
+                 generate_device: str = "cuda",
+                 chunck_size: int = 1000,
+                 **kwargs):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device,
+                                    **kwargs)
+        self.orig_module.__init__(self.orig_module.config,
+                                  orig_module.layer_idx)
+        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.
+
+    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors.
+
+        Args:
+            q (`torch.Tensor`): The query tensor.
+            k (`torch.Tensor`): The key tensor.
+            cos (`torch.Tensor`): The cosine part of the rotary embedding.
+            sin (`torch.Tensor`): The sine part of the rotary embedding.
+            position_ids (`torch.Tensor`):
+                Deprecated and unused.
+            unsqueeze_dim (`int`, *optional*, defaults to 1):
+                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+        Returns:
+            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+        """
+        cos = cos.unsqueeze(unsqueeze_dim)
+        sin = sin.unsqueeze(unsqueeze_dim)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    def forward(self,
+                hidden_states: torch.Tensor,
+                position_ids: Optional[torch.Tensor],
+                position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+                attention_mask: Optional[torch.Tensor],
+                past_key_value: Optional[Cache] = None,
+                cache_position: Optional[torch.LongTensor] = None,
+                **kwargs
+                ):
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        if position_embeddings is None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        cos, sin = position_embeddings
+
+        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,  # diff with Llama
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
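
For reference, here is a minimal standalone check (my addition, not part of the diff) of what the new repeat_kv helper does for grouped-query attention: each key/value head is duplicated n_rep times so the KV tensors line up with the query heads, which the docstring equates with torch.repeat_interleave on dim 1.

    # Standalone sketch reproducing repeat_kv from the diff and checking its
    # documented equivalence with torch.repeat_interleave(x, repeats=n_rep, dim=1).
    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    kv = torch.randn(2, 4, 7, 64)      # batch=2, 4 KV heads, seq_len=7, head_dim=64
    expanded = repeat_kv(kv, n_rep=3)  # -> (2, 12, 7, 64), one copy per query head in the group
    assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))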

KT-SFT/ktransformers/operators/experts.py

Lines changed: 121 additions & 0 deletions
@@ -2071,3 +2071,124 @@ def moe_infer(self, x, topk_ids, topk_weight):
             .type(new_x.dtype)
         )
         return final_out
+
+
+class KQwen3MoeSparseMoeBlock(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
+    def forward(self, hidden_states):
+
+        orig_shape = hidden_states.shape
+        sequence_length = orig_shape[1]
+
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+        router_logits = self.gate(hidden_states)
+
+        if router_logits.device.type == "xpu":
+            from ipex_llm.transformers.models.common import moe_softmax_topk
+            selected_experts, routing_weights = moe_softmax_topk(
+                router_logits.half(), self.top_k, self.norm_topk_prob
+            )
+        else:
+            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+            if self.norm_topk_prob:
+                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+            # we cast back to the input dtype
+            routing_weights = routing_weights.to(hidden_states.dtype)
+
+        # only for generate phase
+        if sequence_length == 1 and hasattr(self.experts.generate_experts,
+                                            "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():  # TODO: this branch cause jit bug
+            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0],
+                                                                routing_weights[0])
+            # y_ = self.shared_expert(hidden_states).squeeze(0)
+            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
+
+            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
+
+            # y += y_
+            y.resize_(*orig_shape)
+            return y
+
+        # y_ = self.shared_expert(hidden_states).squeeze(0)
+        # y_ = (
+        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
+        # )
+
+        if isinstance(self.experts, KExpertsBase):
+            y = self.moe_kexperts(hidden_states, selected_experts, routing_weights).view(*orig_shape).to(
+                device=hidden_states.device)
+        elif hidden_states.size(0) > 10:
+            # TODO may bugs here
+            y = (
+                self.moe_infer(hidden_states, selected_experts, routing_weights)
+                .view(*orig_shape)
+                .to(device=hidden_states.device)
+            )
+        else:
+            # TODO may bugs here
+            y = (
+                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
+                .view(*orig_shape)
+                .to(device=hidden_states.device)
+            )
+        # y += y_
+        return y
+
+    @maybe_no_grad()
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
+        outs = self.experts(x, topk_ids, topk_weight)
+        return outs
+
+    @maybe_no_grad()
+    # TODO may bugs here
+    def moe_infer_simple(
+        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        x: [num_tokens, hidden_size]
+        topk_ids, topk_weight: [num_tokens, num_selected_experts]
+        """
+        outs = torch.zeros_like(x)
+        for token_idx in range(topk_ids.size(0)):
+            for expert_idx in range(topk_ids.size(1)):
+                expert = self.experts[topk_ids[token_idx, expert_idx]]
+                outs[token_idx] += (
+                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
+                )
+        return outs
+
+    @maybe_no_grad()
+    # TODO may bugs here
+    def moe_infer(self, x, topk_ids, topk_weight):
+        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
+        cnts.scatter_(1, topk_ids, 1)
+        tokens_per_expert = cnts.sum(dim=0)
+        idxs = topk_ids.view(-1).argsort()
+        sorted_tokens = x[idxs // topk_ids.shape[1]]
+        tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+        outputs = []
+        start_idx = 0
+        for i, num_tokens in enumerate(tokens_per_expert):
+            end_idx = start_idx + num_tokens
+            if num_tokens == 0:
+                continue
+            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
+            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+            expert_out = expert.forward(tokens_for_this_expert)
+            outputs.append(expert_out)
+            start_idx = end_idx
+
+        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+
+        new_x = torch.empty_like(outs)
+        new_x[idxs] = outs
+        final_out = (
+            new_x.view(*topk_ids.shape, -1)
+            .type(topk_weight.dtype)
+            .mul_(topk_weight.unsqueeze(dim=-1))
+            .sum(dim=1)
+            .type(new_x.dtype)
+        )
+        return final_out
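
The router in KQwen3MoeSparseMoeBlock.forward follows the usual Qwen MoE recipe on non-XPU devices: softmax over all experts, keep the top-k weights, and optionally renormalize them per token. A self-contained sketch of just that step (my paraphrase for illustration; the 128-expert and top-8 sizes are arbitrary, not taken from the commit):

    # Hedged sketch of the non-XPU routing math, mirroring the diff above.
    import torch

    def route_tokens(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bool):
        # softmax over the expert dimension, then keep the k largest weights per token
        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
        if norm_topk_prob:
            # renormalize so the selected weights sum to 1 for each token
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        return routing_weights, selected_experts

    logits = torch.randn(5, 128)          # 5 tokens, 128 experts (sizes chosen arbitrarily)
    weights, experts = route_tokens(logits, top_k=8, norm_topk_prob=True)
    print(weights.shape, experts.shape)   # torch.Size([5, 8]) torch.Size([5, 8])
    print(weights.sum(dim=-1))            # approximately 1.0 per token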
