-
Notifications
You must be signed in to change notification settings - Fork 30.7k
add rotary kernel support to Qwen3 model #41147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
69f2ca8
d2bf5c5
b0cbab5
8dede65
5c02189
137069b
8ac3e1e
29f83f2
7729b7f
94e4f60
b96a7c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
from ...cache_utils import Cache, DynamicCache | ||
from ...generation import GenerationMixin | ||
from ...integrations import use_kernel_forward_from_hub | ||
from ...integrations.hub_kernels import rotary_kernel | ||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask | ||
from ...modeling_flash_attention_utils import FlashAttentionKwargs | ||
from ...modeling_layers import ( | ||
|
@@ -46,6 +47,10 @@ | |
from .configuration_qwen3 import Qwen3Config | ||
|
||
|
||
# Global variable to track kernel usage, set by model instances | ||
use_kernels = False | ||
|
||
|
||
@use_kernel_forward_from_hub("RMSNorm") | ||
class Qwen3RMSNorm(nn.Module): | ||
def __init__(self, hidden_size, eps: float = 1e-6) -> None: | ||
|
@@ -117,6 +122,34 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): | |
return q_embed, k_embed | ||
|
||
|
||
def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): | ||
""" | ||
Rotary kernel implementation wrapper | ||
Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature | ||
""" | ||
cos = cos.unsqueeze(unsqueeze_dim) | ||
sin = sin.unsqueeze(unsqueeze_dim) | ||
|
||
q_rotated = q.clone() | ||
k_rotated = k.clone() | ||
|
||
# Get half dimension for rotation | ||
half_dim = q.shape[-1] // 2 | ||
q1 = q_rotated[..., :half_dim] | ||
q2 = q_rotated[..., half_dim:] | ||
k1 = k_rotated[..., :half_dim] | ||
k2 = k_rotated[..., half_dim:] | ||
if cos.shape[-1] != half_dim: | ||
# Trim cos/sin to match half_dim | ||
cos = cos[..., :half_dim] | ||
sin = sin[..., :half_dim] | ||
|
||
# Apply rotary embedding using our kernel | ||
rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) | ||
|
||
rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) | ||
return q_rotated, k_rotated | ||
|
||
|
||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: | ||
""" | ||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, | ||
|
@@ -202,7 +235,10 @@ def forward( | |
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) | ||
|
||
cos, sin = position_embeddings | ||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) | ||
if rotary_kernel: | ||
query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) | ||
else: | ||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) | ||
|
||
if past_key_values is not None: | ||
# sin and cos are specific to RoPE models; cache_position needed for the static cache | ||
|
@@ -477,6 +513,10 @@ def forward( | |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | ||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." | ||
```""" | ||
# Set global use_kernels flag based on model's kernel usage | ||
global use_kernels | ||
use_kernels = getattr(self, "use_kernels", False) | ||
|
||
|
||
outputs: BaseModelOutputWithPast = self.model( | ||
input_ids=input_ids, | ||
attention_mask=attention_mask, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to lazily load the kernel, because here we are loading it before even knowing if the user wants to use kernels or not
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx for your advice! Have updated related code