add rotary kernel support to Qwen3 model #41147
base: main
Conversation
Signed-off-by: Liu, Kaixuan <[email protected]>
Signed-off-by: Liu, Kaixuan <[email protected]>
I made a benchmark for it.
Signed-off-by: Liu, Kaixuan <[email protected]>
Signed-off-by: Liu, Kaixuan <[email protected]>
…rmers into rotary-kernel
Signed-off-by: Liu, Kaixuan <[email protected]>
Signed-off-by: Liu, Kaixuan <[email protected]>
Thanks for this integration @kaixuanliu! I left a few nits to consider.
global use_kernels
use_kernels = getattr(self, "use_kernels", False)
It would be better to pass an attention kwarg, for example use_rotary_kernel, rather than defining a global variable like this.
You mean add a param called use_rotary_kernel to kwargs here, and pass it down to Qwen3Attention?
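A minimal sketch of that kwarg-based dispatch, with ToyAttention standing in for Qwen3Attention and apply_rotary_kernel as a runnable stand-in for the hub kernel wrapper (this is an illustration, not the PR's actual diff):

import torch
from torch import nn


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Default eager rotary path, matching the usual HF signature.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


def apply_rotary_kernel(q, k, cos, sin, unsqueeze_dim=1):
    # Stand-in for the hub-kernel wrapper added in this PR; it falls back to
    # the eager path so the sketch stays runnable without the kernel installed.
    return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim)


class ToyAttention(nn.Module):
    # Illustrative stand-in for Qwen3Attention: the rotary backend is picked
    # per call from an attention kwarg rather than a module-level global.
    def forward(self, q, k, cos, sin, **kwargs):
        if kwargs.get("use_rotary_kernel", False):
            return apply_rotary_kernel(q, k, cos, sin)
        return apply_rotary_pos_emb(q, k, cos, sin)


attn = ToyAttention()
q = torch.randn(1, 2, 8, 64)
k = torch.randn(1, 2, 8, 64)
cos = torch.randn(1, 8, 64)
sin = torch.randn(1, 8, 64)
q_rot, k_rot = attn(q, k, cos, sin, use_rotary_kernel=True)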
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations.hub_kernels import rotary_kernel
I think we need to lazily load the kernel, because here we are loading it before even knowing whether the user wants to use kernels or not.
Thanks for your advice! I have updated the related code.
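For context, one minimal way to defer the kernel load (a sketch, not necessarily what the updated diff does) is to keep the same import as above but resolve it only on first use, inside the model module:

_rotary_kernel = None


def _get_rotary_kernel():
    # Resolve the hub kernel only on first use, so importing the model file
    # does not trigger a kernel load for users who never enable kernels.
    global _rotary_kernel
    if _rotary_kernel is None:
        from ...integrations.hub_kernels import rotary_kernel

        _rotary_kernel = rotary_kernel
    return _rotary_kernel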
def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Rotary kernel implementation wrapper
    Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_rotated = q.clone()
    k_rotated = k.clone()

    # Get half dimension for rotation
    half_dim = q.shape[-1] // 2
    q1 = q_rotated[..., :half_dim]
    q2 = q_rotated[..., half_dim:]
    k1 = k_rotated[..., :half_dim]
    k2 = k_rotated[..., half_dim:]
    if cos.shape[-1] != half_dim:
        # Trim cos/sin to match half_dim
        cos = cos[..., :half_dim]
        sin = sin[..., :half_dim]

    # Apply rotary embedding using our kernel
    rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
Did you try to benchmark the performance with and without this kernel?
Yes. On Intel XPU, a single rotary op takes 0.22 ms, and it drops to 0.1 ms after applying this patch, which is over a 2x speedup.
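For anyone who wants to reproduce a number like this, a rough standalone timing sketch could look like the following (not the exact benchmark used here; the tensor shapes, dtype, and the torch.xpu.synchronize call are assumptions):

import time

import torch


def bench_rotary(rotary_fn, device="xpu", iters=100, warmup=10):
    # Time one rotary application end to end; rotary_fn can be either the
    # eager apply_rotary_pos_emb or the kernel-backed wrapper.
    q = torch.randn(1, 8, 1024, 128, device=device, dtype=torch.float16)
    k = torch.randn(1, 8, 1024, 128, device=device, dtype=torch.float16)
    cos = torch.randn(1, 1024, 128, device=device, dtype=torch.float16)
    sin = torch.randn(1, 1024, 128, device=device, dtype=torch.float16)

    sync = torch.xpu.synchronize if device == "xpu" else torch.cuda.synchronize
    for _ in range(warmup):
        rotary_fn(q, k, cos, sin)
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        rotary_fn(q, k, cos, sin)
    sync()
    return (time.perf_counter() - start) / iters * 1e3  # ms per call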
Signed-off-by: Liu, Kaixuan <[email protected]>
Signed-off-by: Liu, Kaixuan <[email protected]>
…rmers into rotary-kernel
[For maintainers] Suggested jobs to run (before merge)
run-slow: dots1, qwen3, qwen3_moe, qwen3_omni_moe
No description provided.