5 changes: 5 additions & 0 deletions src/transformers/integrations/hub_kernels.py
@@ -114,6 +114,8 @@
    }

    register_kernel_mapping(_KERNEL_MAPPING)
    # Preload the rotary kernel as it's used in many models.
    rotary_kernel = get_kernel(repo_id="kernels-community/rotary")

except ImportError:
    _kernels_available = False
@@ -138,6 +140,8 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
    def register_kernel_mapping(*args, **kwargs):
        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")

    rotary_kernel = None


def is_kernel(attn_implementation: Optional[str]) -> bool:
"""Check whether `attn_implementation` matches a kernel pattern from the hub."""
@@ -201,4 +205,5 @@ def load_and_register_kernel(attn_implementation: str) -> None:
    "use_kernel_forward_from_hub",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
    "rotary_kernel",
]
42 changes: 41 additions & 1 deletion src/transformers/models/qwen3/modeling_qwen3.py
@@ -28,6 +28,7 @@
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations.hub_kernels import rotary_kernel
Contributor:

I think we need to lazily load the kernel, because here we are loading it before we even know whether the user wants to use kernels or not.

Contributor Author (kaixuanliu, Sep 26, 2025):

Thanks for your advice! I have updated the related code.
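A minimal sketch of the lazy loading suggested above, assuming a small helper in hub_kernels.py; the name get_rotary_kernel and the module-level cache are illustrative, not part of this PR:

```python
from typing import Optional

_rotary_kernel: Optional[object] = None


def get_rotary_kernel() -> Optional[object]:
    """Load kernels-community/rotary the first time it is needed, not at import time."""
    global _rotary_kernel
    if _rotary_kernel is None:
        try:
            from kernels import get_kernel

            _rotary_kernel = get_kernel(repo_id="kernels-community/rotary")
        except ImportError:
            # `kernels` is not installed; callers fall back to apply_rotary_pos_emb.
            return None
    return _rotary_kernel
```

Model code would then call get_rotary_kernel() at the point of use, so nothing is downloaded unless a kernel path is actually taken.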

from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -46,6 +47,10 @@
from .configuration_qwen3 import Qwen3Config


# Global variable to track kernel usage, set by model instances
use_kernels = False


@use_kernel_forward_from_hub("RMSNorm")
class Qwen3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
@@ -117,6 +122,34 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    return q_embed, k_embed


def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Wrapper around the hub rotary kernel.
    Adapts the kernels-community/rotary in-place API to the HuggingFace `apply_rotary_pos_emb` signature.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_rotated = q.clone()
    k_rotated = k.clone()

    # Get half dimension for rotation
    half_dim = q.shape[-1] // 2
    q1 = q_rotated[..., :half_dim]
    q2 = q_rotated[..., half_dim:]
    k1 = k_rotated[..., :half_dim]
    k2 = k_rotated[..., half_dim:]
    if cos.shape[-1] != half_dim:
        # Trim cos/sin to match half_dim
        cos = cos[..., :half_dim]
        sin = sin[..., :half_dim]

    # Apply rotary embedding using the hub kernel (in place on q1/q2)
    rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
Contributor:

Did you try to benchmark the performance with and without this kernel?

Contributor Author:

Yes. On Intel XPU, a single rotary op takes 0.22 ms, and it drops to 0.1 ms with this patch, which is more than a 2x speedup.
    rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False)
    return q_rotated, k_rotated
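The 2x figure quoted in the review above could be reproduced with a micro-benchmark along these lines; this is an illustrative sketch rather than code from the PR, and the shapes, dtype, and device handling are assumptions:

```python
import time

import torch


def time_rope(fn, q, k, cos, sin, iters=100, warmup=10):
    """Return the average time per call in milliseconds, synchronizing around the timed region."""
    backend = getattr(torch, q.device.type, None)  # e.g. torch.xpu or torch.cuda

    def sync():
        if backend is not None and hasattr(backend, "synchronize"):
            backend.synchronize()

    for _ in range(warmup):
        fn(q, k, cos, sin)
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn(q, k, cos, sin)
    sync()
    return (time.perf_counter() - start) / iters * 1e3


# Example usage (shapes are illustrative): compare the eager path with the kernel wrapper.
# q = torch.randn(1, 32, 1024, 128, device="xpu", dtype=torch.float16)
# k = torch.randn(1, 32, 1024, 128, device="xpu", dtype=torch.float16)
# cos = torch.randn(1, 1024, 128, device="xpu", dtype=torch.float16)
# sin = torch.randn(1, 1024, 128, device="xpu", dtype=torch.float16)
# print(time_rope(apply_rotary_pos_emb, q, k, cos, sin), time_rope(apply_rotary_kernel, q, k, cos, sin))
```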


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -202,7 +235,10 @@ def forward(
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if rotary_kernel:
            query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -477,6 +513,10 @@ def forward(
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Set global use_kernels flag based on model's kernel usage
        global use_kernels
        use_kernels = getattr(self, "use_kernels", False)

Contributor:

It's better to pass an attention kwarg, for example use_rotary_kernel, than to define a global variable like this.

Contributor Author:

You mean add a parameter called use_rotary_kernel to kwargs here and pass it down to Qwen3Attention?
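A sketch of the kwarg-based alternative discussed above: thread a use_rotary_kernel flag through the attention kwargs instead of reading a module-level global. The helper name and the exact plumbing are illustrative, not part of this PR:

```python
def select_rotary_impl(q, k, cos, sin, use_rotary_kernel: bool = False):
    """Dispatch to the hub rotary kernel when requested and available, else fall back to eager RoPE."""
    if use_rotary_kernel and rotary_kernel is not None:
        return apply_rotary_kernel(q, k, cos, sin)
    return apply_rotary_pos_emb(q, k, cos, sin)


# Qwen3Attention.forward would then read the flag from its kwargs, roughly:
#     query_states, key_states = select_rotary_impl(
#         query_states, key_states, cos, sin,
#         use_rotary_kernel=kwargs.get("use_rotary_kernel", False),
#     )
```

This keeps the choice per call and avoids one model instance toggling a global that another instance would also read.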

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,