diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index 5be21e2f9a51..e5f89589442e 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -114,6 +114,8 @@
     }
 
     register_kernel_mapping(_KERNEL_MAPPING)
+    # Preload the rotary kernel as it's used in many models.
+    rotary_kernel = get_kernel(repo_id="kernels-community/rotary")
 
 except ImportError:
     _kernels_available = False
@@ -138,6 +140,8 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
     def register_kernel_mapping(*args, **kwargs):
         raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
 
+    rotary_kernel = None
+
 
 def is_kernel(attn_implementation: Optional[str]) -> bool:
     """Check whether `attn_implementation` matches a kernel pattern from the hub."""
@@ -201,4 +205,5 @@ def load_and_register_kernel(attn_implementation: str) -> None:
     "use_kernel_forward_from_hub",
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
+    "rotary_kernel",
 ]
diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py
index ea500c064512..4f419e1f1cbb 100644
--- a/src/transformers/models/dots1/modeling_dots1.py
+++ b/src/transformers/models/dots1/modeling_dots1.py
@@ -170,6 +170,36 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """
+    Rotary kernel implementation wrapper
+    Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature
+    """
+    from ...integrations.hub_kernels import rotary_kernel
+
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    q_rotated = q.clone()
+    k_rotated = k.clone()
+
+    # Get half dimension for rotation
+    half_dim = q.shape[-1] // 2
+    q1 = q_rotated[..., :half_dim]
+    q2 = q_rotated[..., half_dim:]
+    k1 = k_rotated[..., :half_dim]
+    k2 = k_rotated[..., half_dim:]
+    if cos.shape[-1] != half_dim:
+        # Trim cos/sin to match half_dim
+        cos = cos[..., :half_dim]
+        sin = sin[..., :half_dim]
+
+    # Apply rotary embedding using our kernel
+    rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+    rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+    return q_rotated, k_rotated
+
+
 class Dots1Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -217,7 +247,16 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        # Check if use_kernels is passed in kwargs
+        use_kernels = kwargs.get("use_kernels", False)
+        if use_kernels:
+            try:
+                query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
+            except (ImportError, AttributeError, RuntimeError):
+                # Fallback to regular rotary position embedding if kernel is not available
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -580,6 +619,7 @@ def forward(
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
+        use_kernels = getattr(self, "use_kernels", False)
         outputs: BaseModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -588,6 +628,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             cache_position=cache_position,
+            use_kernels=use_kernels,
             **kwargs,
         )
 
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index 81b16c4ee6b6..8f34c0ebc574 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -155,6 +155,36 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """
+    Rotary kernel implementation wrapper
+    Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature
+    """
+    from ...integrations.hub_kernels import rotary_kernel
+
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    q_rotated = q.clone()
+    k_rotated = k.clone()
+
+    # Get half dimension for rotation
+    half_dim = q.shape[-1] // 2
+    q1 = q_rotated[..., :half_dim]
+    q2 = q_rotated[..., half_dim:]
+    k1 = k_rotated[..., :half_dim]
+    k2 = k_rotated[..., half_dim:]
+    if cos.shape[-1] != half_dim:
+        # Trim cos/sin to match half_dim
+        cos = cos[..., :half_dim]
+        sin = sin[..., :half_dim]
+
+    # Apply rotary embedding using our kernel
+    rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+    rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+    return q_rotated, k_rotated
+
+
 class Qwen3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -202,7 +232,16 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        # Check if use_kernels is passed in kwargs
+        use_kernels = kwargs.get("use_kernels", False)
+        if use_kernels:
+            try:
+                query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
+            except (ImportError, AttributeError, RuntimeError):
+                # Fallback to regular rotary position embedding if kernel is not available
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -477,6 +516,7 @@ def forward(
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
+        use_kernels = getattr(self, "use_kernels", False)
         outputs: BaseModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -485,6 +525,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             cache_position=cache_position,
+            use_kernels=use_kernels,
             **kwargs,
         )
 
diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py
index f1e38841faf4..65babb1e3eb6 100644
--- a/src/transformers/models/qwen3/modular_qwen3.py
+++ b/src/transformers/models/qwen3/modular_qwen3.py
@@ -14,13 +14,13 @@
 # limitations under the License.
 """PyTorch Qwen3 model."""
 
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 
 import torch
 
 from ...cache_utils import Cache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, logging
@@ -49,6 +49,36 @@
 _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B"
 
 
+def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """
+    Rotary kernel implementation wrapper
+    Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature
+    """
+    from ...integrations.hub_kernels import rotary_kernel
+
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    q_rotated = q.clone()
+    k_rotated = k.clone()
+
+    # Get half dimension for rotation
+    half_dim = q.shape[-1] // 2
+    q1 = q_rotated[..., :half_dim]
+    q2 = q_rotated[..., half_dim:]
+    k1 = k_rotated[..., :half_dim]
+    k2 = k_rotated[..., half_dim:]
+    if cos.shape[-1] != half_dim:
+        # Trim cos/sin to match half_dim
+        cos = cos[..., :half_dim]
+        sin = sin[..., :half_dim]
+
+    # Apply rotary embedding using our kernel
+    rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+    rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+    return q_rotated, k_rotated
+
+
 class Qwen3RMSNorm(Qwen2RMSNorm):
     pass
 
@@ -82,7 +112,16 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        # Check if use_kernels is passed in kwargs
+        use_kernels = kwargs.get("use_kernels", False)
+        if use_kernels:
+            try:
+                query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
+            except (ImportError, AttributeError, RuntimeError):
+                # Fallback to regular rotary position embedding if kernel is not available
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -125,7 +164,16 @@ class Qwen3Model(Qwen2Model):
 class Qwen3ForCausalLM(Qwen2ForCausalLM):
     def forward(
         self,
-        **super_kwargs: Unpack[TransformersKwargs],
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -149,7 +197,35 @@ def forward(
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        return super().forward(**super_kwargs)
+        use_kernels = getattr(self, "use_kernels", False)
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            use_kernels=use_kernels,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
 
 
 class Qwen3ForSequenceClassification(Qwen2ForSequenceClassification):
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index 2056e7c76a3a..174b2854e8af 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -119,6 +119,36 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """
+    Rotary kernel implementation wrapper
+    Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature
+    """
+    from ...integrations.hub_kernels import rotary_kernel
+
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    q_rotated = q.clone()
+    k_rotated = k.clone()
+
+    # Get half dimension for rotation
+    half_dim = q.shape[-1] // 2
+    q1 = q_rotated[..., :half_dim]
+    q2 = q_rotated[..., half_dim:]
+    k1 = k_rotated[..., :half_dim]
+    k2 = k_rotated[..., half_dim:]
+    if cos.shape[-1] != half_dim:
+        # Trim cos/sin to match half_dim
+        cos = cos[..., :half_dim]
+        sin = sin[..., :half_dim]
+
+    # Apply rotary embedding using our kernel
+    rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+    rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+    return q_rotated, k_rotated
+
+
 class Qwen3MoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -166,7 +196,16 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        # Check if use_kernels is passed in kwargs
+        use_kernels = kwargs.get("use_kernels", False)
+        if use_kernels:
+            try:
+                query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
+            except (ImportError, AttributeError, RuntimeError):
+                # Fallback to regular rotary position embedding if kernel is not available
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
index 1172ebf90919..8a1e0b786a27 100644
--- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
@@ -1397,6 +1397,36 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
+def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """
+    Rotary kernel implementation wrapper
+    Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature
+    """
+    from ...integrations.hub_kernels import rotary_kernel
+
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    q_rotated = q.clone()
+    k_rotated = k.clone()
+
+    # Get half dimension for rotation
+    half_dim = q.shape[-1] // 2
+    q1 = q_rotated[..., :half_dim]
+    q2 = q_rotated[..., half_dim:]
+    k1 = k_rotated[..., :half_dim]
+    k2 = k_rotated[..., half_dim:]
+    if cos.shape[-1] != half_dim:
+        # Trim cos/sin to match half_dim
+        cos = cos[..., :half_dim]
+        sin = sin[..., :half_dim]
+
+    # Apply rotary embedding using our kernel
+    rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+    rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+    return q_rotated, k_rotated
+
+
 class Qwen3OmniMoeThinkerTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -1448,7 +1478,16 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        # Check if use_kernels is passed in kwargs
+        use_kernels = kwargs.get("use_kernels", False)
+        if use_kernels:
+            try:
+                query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
+            except (ImportError, AttributeError, RuntimeError):
+                # Fallback to regular rotary position embedding if kernel is not available
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -2324,7 +2363,16 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        # Check if use_kernels is passed in kwargs
+        use_kernels = kwargs.get("use_kernels", False)
+        if use_kernels:
+            try:
+                query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
+            except (ImportError, AttributeError, RuntimeError):
+                # Fallback to regular rotary position embedding if kernel is not available
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -3375,7 +3423,16 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        # Check if use_kernels is passed in kwargs
+        use_kernels = kwargs.get("use_kernels", False)
+        if use_kernels:
+            try:
+                query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
+            except (ImportError, AttributeError, RuntimeError):
+                # Fallback to regular rotary position embedding if kernel is not available
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
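Reviewer note: the sketch below is a minimal, CPU-only sanity check of what the new `apply_rotary_kernel` wrapper is expected to compute. `apply_rotary_reference` encodes the *assumed* in-place semantics of `rotary_kernel.apply_rotary(x1, x2, cos, sin, out1, out2, conj)` (out1 = x1·cos − x2·sin, out2 = x1·sin + x2·cos); `apply_rotary_kernel_wrapper` mirrors the wrapper added in this diff with the kernel call swapped for that reference, and the final assertion checks agreement with the stock `apply_rotary_pos_emb` under the usual duplicated-halves cos/sin layout. The helper names and the kernel semantics here are assumptions for illustration, not the documented API of `kernels-community/rotary`.

```python
import torch


def rotate_half(x):
    # Standard Hugging Face helper: negates and swaps the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Stock eager path: cos/sin carry duplicated halves along the last dimension.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def apply_rotary_reference(x1, x2, cos, sin, o1, o2, conj=False):
    # Assumed semantics of rotary_kernel.apply_rotary: an in-place half rotation.
    sign = -1.0 if conj else 1.0
    r1 = x1 * cos - sign * x2 * sin
    r2 = sign * x1 * sin + x2 * cos
    o1.copy_(r1)
    o2.copy_(r2)


def apply_rotary_kernel_wrapper(q, k, cos, sin, unsqueeze_dim=1):
    # Mirrors the apply_rotary_kernel wrapper in this diff, with the hub kernel
    # call replaced by the pure-PyTorch reference above.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    q_rotated, k_rotated = q.clone(), k.clone()
    half_dim = q.shape[-1] // 2
    q1, q2 = q_rotated[..., :half_dim], q_rotated[..., half_dim:]
    k1, k2 = k_rotated[..., :half_dim], k_rotated[..., half_dim:]
    cos, sin = cos[..., :half_dim], sin[..., :half_dim]
    apply_rotary_reference(q1, q2, cos, sin, q1, q2, False)
    apply_rotary_reference(k1, k2, cos, sin, k1, k2, False)
    return q_rotated, k_rotated


batch, heads, seq, head_dim = 2, 4, 8, 64
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
freqs = torch.randn(batch, seq, head_dim // 2)
cos = torch.cat((freqs, freqs), dim=-1).cos()
sin = torch.cat((freqs, freqs), dim=-1).sin()

torch.testing.assert_close(apply_rotary_kernel_wrapper(q, k, cos, sin), apply_rotary_pos_emb(q, k, cos, sin))
print("wrapper path matches apply_rotary_pos_emb")
```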