
Commit c104388

bobboli and hlu1 authored
chore: Refactor apply_rope. (#4918)
Signed-off-by: Bo Li <[email protected]> Co-authored-by: hlu1 <[email protected]>
1 parent 6b17dff commit c104388

File tree

4 files changed, +77 -66 lines changed


tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 8 additions & 2 deletions
@@ -14,7 +14,7 @@
                                            PredefinedAttentionMask, RopeParams)
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
-from ..modules.attention import Attention, QkNormType
+from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.linear import Linear, TensorParallelMode
@@ -53,7 +53,6 @@ def __init__(
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=pos_embd_params,
-            qk_norm_type=QkNormType.pre_rope,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=False,
@@ -113,6 +112,13 @@ def k_l2norm():

         return q, k

+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
+        # Gemma3 applies QK norm before RoPE.
+        q, k, v = self.split_qkv(q, k, v)
+        q, k = self.apply_qk_norm(q, k)
+        return super().apply_rope(q, k, v, position_ids)
+

 class Gemma3MLP(nn.Module):

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 16 additions & 4 deletions
@@ -23,7 +23,7 @@
 from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
 from ..model_config import ModelConfig
-from ..modules.attention import Attention, QkNormType
+from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (Llama4RenormalizeMoeRoutingMethod,
@@ -60,6 +60,7 @@ def __init__(
             rope=RopeParams.from_config(config),
             is_neox=False,
         ) if self.use_rope else None
+        self.use_qk_norm = use_qk_norm

         if model_config.attn_backend != "TRTLLM":
             # TODO: support chunked attention for other backends.
@@ -74,15 +75,15 @@ def __init__(
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
             pos_embd_params=pos_embd_params,
-            qk_norm_type=QkNormType.post_rope
-            if use_qk_norm else QkNormType.none,
+            rope_fusion=not self.
+            use_qk_norm,  # Llama4 uses qk_norm after RoPE, so it is not possible to fuse RoPE into the attention OP with qk_norm.
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
             attention_chunk_size=attention_chunk_size,
         )

-        if self.use_rope and use_qk_norm:
+        if self.use_qk_norm:
             self.head_dim = config.hidden_size // config.num_attention_heads
             self.qk_norm = RMSNorm(hidden_size=self.head_dim,
                                    eps=1e-6,
@@ -115,6 +116,17 @@ def k_l2norm():

         return q, k

+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
+        q, k, v = self.split_qkv(q, k, v)
+        if position_ids is not None:
+            q, k, v = super().apply_rope(q, k, v, position_ids)
+        # Llama4 applies QK norm after RoPE.
+        if self.use_qk_norm:
+            q, k = self.apply_qk_norm(q, k)
+
+        return q, k, v
+
     def _attention_scaling(self, q, position_ids):

         def _get_attn_scale(position_ids: torch.Tensor) -> torch.Tensor:
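
The inline comment on rope_fusion above carries the key constraint: Llama4 normalizes q/k after RoPE, and a fused attention OP applies RoPE internally, so there is no point left at which to run the norm. A condensed view of the two resulting paths (an illustrative sketch, not code from this commit; LLAMA4_ROPE_PATHS is a made-up name, and it assumes use_rope is enabled and the attention backend supports fused RoPE):

# Illustrative summary only; not a symbol in the TRT-LLM codebase.
LLAMA4_ROPE_PATHS = {
    # use_qk_norm: (rope_fusion passed to Attention, effective ordering)
    True: (False, "unfused RoPE in apply_rope, then apply_qk_norm (post-RoPE)"),
    False: (True, "RoPE fused into the attention OP, no QK norm applied"),
}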

tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 16 additions & 14 deletions
@@ -9,7 +9,7 @@
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..model_config import ModelConfig
-from ..modules.attention import Attention, QkNormType
+from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.gated_mlp import GatedMLP
@@ -50,19 +50,15 @@ def __init__(
             num_key_value_heads=config.num_key_value_heads,
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
-            pos_embd_params=pos_embd_params
-            if not self.fuse_qk_norm_rope else None,
-            qk_norm_type=QkNormType.pre_rope,
+            pos_embd_params=pos_embd_params,
+            rope_fusion=not self.
+            fuse_qk_norm_rope,  # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope.
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=config.attention_bias,
             config=model_config,
         )

-        # If fuse_qk_norm_rope is true, we pass pos_embd_params=None to super().__init__,
-        # so we need to do assignment to record the actual pos_embd_params.
-        self.pos_embd_params = pos_embd_params
-
         self.q_norm = RMSNorm(hidden_size=self.head_dim,
                               eps=1e-6,
                               dtype=config.torch_dtype,
@@ -94,12 +90,6 @@ def k_l2norm():

         return q, k

-    def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
-        if not self.fuse_qk_norm_rope:
-            return super().apply_rope(qkv, position_ids)
-        else:
-            return self.apply_qk_norm_rope(qkv, position_ids)
-
     def apply_qk_norm_rope(self, qkv, position_ids):
         torch.ops.trtllm.fused_qk_norm_rope(
             qkv, self.num_heads, self.num_key_value_heads,
@@ -109,6 +99,18 @@ def apply_qk_norm_rope(self, qkv, position_ids):
             self.pos_embd_params.is_neox, position_ids.view(-1))
         return qkv, None, None

+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
+        # Qwen3 applies QK norm before RoPE.
+        if not self.fuse_qk_norm_rope:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.apply_qk_norm(q, k)
+            return super().apply_rope(q, k, v, position_ids)
+
+        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
+        qkv = q
+        return self.apply_qk_norm_rope(qkv, position_ids)
+

 class Qwen3DecoderLayer(DecoderLayer):
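
Because rope_fusion here is derived as not self.fuse_qk_norm_rope, the double negative is easy to misread. A condensed view of the two Qwen3 paths (an illustrative sketch, not code from this commit; QWEN3_ROPE_PATHS is a made-up name, and the TRTLLM backend is assumed to support fused RoPE):

# Illustrative summary only; not a symbol in the TRT-LLM codebase.
QWEN3_ROPE_PATHS = {
    # fuse_qk_norm_rope: (rope_fusion passed to Attention, where QK norm and RoPE run)
    True: (False, "torch.ops.trtllm.fused_qk_norm_rope in apply_qk_norm_rope; self.rotary_emb is skipped"),
    False: (True, "apply_qk_norm in the overridden apply_rope, RoPE fused into the attention OP"),
}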

tensorrt_llm/_torch/modules/attention.py

Lines changed: 37 additions & 46 deletions
@@ -1,11 +1,11 @@
 import math
 import weakref
-from enum import IntEnum
 from typing import Optional, Union, cast

 import torch
 from torch import nn

+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping

 from ..attention_backend import (AttentionInputType, AttentionMetadata,
@@ -23,15 +23,6 @@
 from .rotary_embedding import RotaryEmbedding


-class QkNormType(IntEnum):
-    """
-    The type of QK normalization.
-    """
-    none = 0  # No normalization applied to Q and K
-    pre_rope = 1  # Apply normalization before Rope
-    post_rope = 2  # Apply normalization after Rope
-
-
 class Attention(nn.Module):

     def __init__(
@@ -43,7 +34,7 @@ def __init__(
         max_position_embeddings: int,
         bias: bool,
         pos_embd_params: Optional[PositionalEmbeddingParams] = None,
-        qk_norm_type: QkNormType = QkNormType.none,
+        rope_fusion: Optional[bool] = None,
         layer_idx: Optional[int] = None,
         dtype: torch.dtype = None,
         dense_bias: Optional[bool] = None,
@@ -60,14 +51,14 @@ def __init__(
             num_key_value_heads (int): The number of key value heads.
             max_position_embeddings (int): The maximum position embeddings.
             bias (bool): Whether to use bias in the linear layers.
-            pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters.
-            qk_norm_type (QkNormType): The type of QK normalization.
-            layer_idx (int): The layer index.
+            pos_embd_params (Optional[PositionalEmbeddingParams]): The positional embedding parameters.
+            rope_fusion (Optional[bool]): Whether to fuse RoPE into the attention OP and skip applying unfused RoPE. If None, whether to fuse is decided by the capability of the attention backend.
+            layer_idx (Optional[int]): The layer index.
             dtype (torch.dtype): The data type.
-            dense_bias (bool): Whether to use bias in the output projection layer.
-            config (ModelConfig): The model configuration.
+            dense_bias (Optional[bool]): Whether to use bias in the output projection layer.
+            config (Optional[ModelConfig]): The model configuration.
             q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
-            attention_chunk_size (int): See [Chunked Attention] below.
+            attention_chunk_size (Optional[int]): See [Chunked Attention] below.
         """
         super().__init__()
         self.layer_idx = layer_idx
@@ -81,7 +72,6 @@ def __init__(
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = max_position_embeddings
         self.pos_embd_params = pos_embd_params
-        self.qk_norm_type = qk_norm_type
         self.dense_bias = dense_bias
         self.q_scaling = q_scaling

@@ -169,14 +159,21 @@ def __init__(
         self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                 [self.hidden_size])

-        # enable_rope_fusion: Whether to fuse RoPE into the attention OP.
+        # Whether to fuse RoPE into the attention OP.
         # If true, RoPE will be applied in self.attn.forward.
         # If false, RoPE will be applied in self.apply_rope.
-        self.enable_rope_fusion = attn_cls.support_fused_rope(
-        ) and self.qk_norm_type != QkNormType.post_rope
+        self.rope_fusion = rope_fusion
+        if self.rope_fusion and not attn_cls.support_fused_rope():
+            logger.warning(
+                "rope_fusion is true but the attention backend does not support it. Will disable rope_fusion."
+            )
+            self.rope_fusion = False
+        # If rope_fusion is not specified, enable if the attention backend supports it.
+        if self.rope_fusion is None:
+            self.rope_fusion = attn_cls.support_fused_rope()

         self.rotary_emb = None
-        if not self.enable_rope_fusion and self.pos_embd_params is not None:
+        if not self.rope_fusion and self.pos_embd_params is not None:
             self.rotary_emb = RotaryEmbedding(
                 self.pos_embd_params.rope,
                 head_dim=self.head_dim,
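
Before the remaining hunks of this file, the rope_fusion resolution in the hunk above is worth restating: an explicit True is downgraded when the backend cannot fuse, an explicit False is respected, and None defers to the backend capability. A standalone helper for readability (an illustrative sketch; resolve_rope_fusion is not a function in this repository):

from typing import Optional


def resolve_rope_fusion(requested: Optional[bool],
                        backend_supports_fused_rope: bool) -> bool:
    # Mirrors the constructor logic: True is downgraded to False (with a warning
    # in the real code) when the backend lacks fused RoPE; None follows the backend.
    if requested and not backend_supports_fused_rope:
        return False
    if requested is None:
        return backend_supports_fused_rope
    return requested


assert resolve_rope_fusion(None, True) is True    # default follows the backend
assert resolve_rope_fusion(None, False) is False
assert resolve_rope_fusion(True, False) is False  # requested but unsupported
assert resolve_rope_fusion(False, True) is False  # explicit opt-out is respected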
@@ -189,8 +186,7 @@ def __init__(
             self.num_heads,
             self.head_dim,
             self.num_key_value_heads,
-            pos_embd_params=self.pos_embd_params
-            if self.enable_rope_fusion else None,
+            pos_embd_params=self.pos_embd_params if self.rope_fusion else None,
             quant_config=self.quant_config,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             q_scaling=self.q_scaling,
@@ -263,7 +259,9 @@ def forward(
         if qkv_lora is not None:
             qkv = qkv + qkv_lora

-        q, k, v = self.apply_rope(qkv, position_ids)
+        q, k, v = qkv, None, None
+
+        q, k, v = self.apply_rope(q, k, v, position_ids)

         out_scale = None
         out_scale_sf = None
@@ -290,32 +288,25 @@
                                         layer_idx=self.layer_idx)
         return attn_output

-    def apply_qk_norm(self, q, k):
-        raise NotImplementedError(
-            f"QK norm is not implemented for {self.__class__.__name__}."
-            "Please override the `apply_qk_norm` method in the subclass.")
-
-    def apply_rope(self, qkv: torch.Tensor, position_ids: torch.Tensor):
+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
         """
-        Apply RoPE to the query and key, possibly including QK norm.
+        Apply RoPE to the query and key.
+        Depending on the implementation, q, k, v could be either fused (q, k, v = concat(q, k, v), None, None) or unfused (none of q, k, v is None).
+        Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn.
+        This method could be overridden in the subclass, in which extra functionalities such as q_norm/k_norm could be added.
         Args:
-            qkv (torch.Tensor): The query, key, and value tensor.
+            q (torch.Tensor): The query tensor.
+            k (Optional[torch.Tensor]): The key tensor.
+            v (Optional[torch.Tensor]): The value tensor.
             position_ids (torch.Tensor): The position IDs of each token for RoPE.
         Returns:
             tuple: A tuple of (q, k, v).
-            This method could be overridden in the subclass, it is possible that k/v is None and q is the concatenated qkv tensor, up to the implementation.
-            Before self.attn.forward, convert_qkv will be called to make sure that the format of (q, k, v) satisfies the requirement of self.attn.
         """
-        q, k, v = qkv, None, None
-        if self.qk_norm_type == QkNormType.pre_rope:
-            q, k, v = self.split_qkv(q, k, v)
-            q, k = self.apply_qk_norm(q, k)
-        if not self.enable_rope_fusion and position_ids is not None:
-            q, k, v = self.split_qkv(q, k, v)
+        q, k, v = self.split_qkv(q, k, v)
+        # If RoPE is fused into the attention OP, do not apply RoPE here.
+        if not self.rope_fusion and position_ids is not None:
             q, k = self.rotary_emb(position_ids, [q, k])
-        if self.qk_norm_type == QkNormType.post_rope:
-            q, k = self.apply_qk_norm(q, k)
-
         return q, k, v


@@ -600,14 +591,14 @@ def yarn_get_mscale(scale=1, mscale=1):
         self.aux_stream = aux_stream
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

-        self.enable_rope_fusion = self.mha.support_fused_rope()
+        self.rope_fusion = self.mha.support_fused_rope()
         self.support_fused_qkv = self.mha.support_fused_qkv()
         self.rotary_emb = RotaryEmbedding(
             pos_embd_params.rope,
             head_dim=self.qk_rope_head_dim,
             is_neox=pos_embd_params.is_neox,
         )
-        self.apply_rotary_emb = not self.enable_rope_fusion
+        self.apply_rotary_emb = not self.rope_fusion

         if not config.skip_create_weights_in_init:
             self.create_weights()
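
Taken together, the model-side diffs replace the QkNormType switch with a single override point on the new (q, k, v, position_ids) signature. A minimal sketch of the pattern (a hypothetical subclass for illustration; MyAttention and its apply_qk_norm are assumptions that mirror what Gemma3 and Qwen3 do, not part of this commit):

from typing import Optional

import torch


class MyAttention(Attention):

    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
        # forward() now hands apply_rope (qkv, None, None); unpack before normalizing.
        q, k, v = self.split_qkv(q, k, v)
        # Pre-RoPE QK norm; apply_qk_norm must be defined by the subclass,
        # since the base-class stub was removed in this commit.
        q, k = self.apply_qk_norm(q, k)
        # The base class applies self.rotary_emb only when rope_fusion is False;
        # otherwise RoPE stays fused inside self.attn.forward.
        return super().apply_rope(q, k, v, position_ids)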

0 commit comments
