
Commit ea997fe

Optimize Gated Delta Rule for Qwen3-Next Prefill
1 parent 8f24442 commit ea997fe

File tree

1 file changed: +25 -15 lines changed

vllm/model_executor/models/qwen3_next.py

Lines changed: 25 additions & 15 deletions
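One of the changes below stops rebuilding the chunk-size identity matrix with torch.eye() on every prefill call and instead creates it once in __init__ and passes it into torch_chunk_gated_delta_rule. A minimal, self-contained sketch of that caching pattern follows; the class name, the use of register_buffer, and the toy forward() are illustrative assumptions rather than the vLLM module, which stores the tensor as a plain attribute pinned to self.conv1d.weight.device.

import torch
import torch.nn as nn


class ChunkMixer(nn.Module):
    # Toy module illustrating the cached-identity pattern (hypothetical names).

    def __init__(self, chunk_size: int = 64):
        super().__init__()
        self.chunk_size = chunk_size
        # Allocated once at construction; every forward call reuses this tensor
        # instead of paying a fresh torch.eye() allocation per invocation.
        self.register_buffer("eye_constant",
                             torch.eye(chunk_size, dtype=torch.float32),
                             persistent=False)

    def forward(self, attn: torch.Tensor) -> torch.Tensor:
        # Before the change the equivalent line was:
        #   attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
        return attn + self.eye_constant


mixer = ChunkMixer()
out = mixer(torch.zeros(2, 4, 64, 64))
print(out.shape)  # torch.Size([2, 4, 64, 64])

A non-persistent buffer is one way to keep the cached tensor on the right device when the module is moved; the commit itself fixes the device at construction time.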
@@ -68,6 +68,7 @@ def torch_chunk_gated_delta_rule(
         value,
         g,
         beta,
+        eye_constant,
         chunk_size=64,
         initial_state=None,
         output_final_state=True,
@@ -113,42 +114,44 @@ def torch_chunk_gated_delta_rule(
 
     # chunk decay
     g = g.cumsum(dim=-1)
+    g_exp = g.exp()
     decay_mask = ((g.unsqueeze(-1) -
                    g.unsqueeze(-2)).tril().exp().float()).tril()
     attn = -((torch.matmul(k_beta.contiguous(),
                            key.transpose(-1, -2).contiguous())) *
              decay_mask).masked_fill(mask, 0)
     for i in range(1, chunk_size):
-        row = attn[..., i, :i].clone()
-        sub = attn[..., :i, :i].clone()
-        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+        row = attn[..., i, :i].contiguous()
+        sub = attn[..., :i, :]
+        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)[..., :i]
+    attn = attn + eye_constant
     value = attn @ v_beta
-    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+    k_cumdecay = attn @ (k_beta * g_exp.unsqueeze(-1))
     last_recurrent_state = (torch.zeros(batch_size, num_heads, k_head_dim,
                                         v_head_dim).to(value) if initial_state
                             is None else initial_state.to(value))
     core_attn_out = torch.zeros_like(value)
-    mask = torch.triu(torch.ones(chunk_size,
+    mask = torch.tril(torch.ones(chunk_size,
                                  chunk_size,
                                  dtype=torch.bool,
                                  device=query.device),
-                      diagonal=1)
+                      diagonal=0)
+    mask = mask.view(1, 1, 1, chunk_size, chunk_size)
+    attn = (query @ key.transpose(-1, -2)) * decay_mask * mask
+    qg = query * g_exp[..., None]
+    delta_g_exp = (g[:, :, :, -1, None] - g).exp()[..., None]
+    k_term = (key * delta_g_exp)
 
     # for each chunk
     for i in range(0, tot_len // chunk_size):
         q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
-        attn = (q_i @ k_i.transpose(-1, -2) *
-                decay_mask[:, :, i]).masked_fill_(mask, 0)
         v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         v_new = v_i - v_prime
-        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        core_attn_out[:, :, i] = attn_inter + attn @ v_new
+        attn_inter = qg[:,:,i] @ last_recurrent_state
+        core_attn_out[:, :, i] = attn_inter + attn[:,:,i] @ v_new
         last_recurrent_state = (
-            last_recurrent_state * g[:, :, i, -1, None, None].exp() +
-            (k_i *
-             (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
-                 -1, -2) @ v_new)
+            last_recurrent_state * g_exp[:, :, i, -1, None, None] +
+            k_term[:,:,i].transpose(-1, -2) @ v_new)
 
     if not output_final_state:
         last_recurrent_state = None
@@ -456,6 +459,11 @@ def __init__(
             dtype=torch.float32,
             device=self.conv1d.weight.device)
 
+        self.chunk_size = 64
+        self.eye_constant = torch.eye(self.chunk_size,
+                                      dtype=torch.float32,
+                                      device=self.conv1d.weight.device)
+
         # time step projection (discretization)
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
         self.dt_bias = nn.Parameter(
@@ -660,6 +668,8 @@ def forward(
                 value_non_spec,
                 g=g,
                 beta=beta,
+                eye_constant=self.eye_constant,
+                chunk_size=self.chunk_size,
                 initial_state=None,
                 output_final_state=True,
                 use_qk_l2norm_in_kernel=True,
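The largest hunk above (@@ -113,42 +114,44 @@) hoists loop-invariant work out of the per-chunk loop: g.exp() is computed once as g_exp, the decayed query and key factors become the precomputed qg and k_term tensors, and the intra-chunk attention is built for all chunks in one batched matmul with a tril mask rather than rebuilt (and masked_fill_'d) inside the loop. The sketch below is a standalone check of the hoisting idea only, with made-up shapes and a stand-in matmul in place of the recurrent-state update; it is not the vLLM kernel.

import torch

torch.manual_seed(0)
b, h, n_chunks, c, d = 1, 2, 4, 8, 16  # batch, heads, chunks, chunk length, head dim
query = torch.randn(b, h, n_chunks, c, d)
key = torch.randn(b, h, n_chunks, c, d)
g = torch.randn(b, h, n_chunks, c).cumsum(dim=-1)  # per-chunk cumulative log-decay

# Per-iteration form (mirrors the removed lines): exp() recomputed each chunk.
out_loop = []
for i in range(n_chunks):
    qg_i = query[:, :, i] * g[:, :, i, :, None].exp()
    k_i = key[:, :, i] * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]
    out_loop.append(qg_i @ k_i.transpose(-1, -2))

# Hoisted form (mirrors the added lines): exp() and decay factors precomputed once.
g_exp = g.exp()
qg = query * g_exp[..., None]
k_term = key * (g[:, :, :, -1, None] - g).exp()[..., None]
out_hoisted = [qg[:, :, i] @ k_term[:, :, i].transpose(-1, -2)
               for i in range(n_chunks)]

for a, b_ in zip(out_loop, out_hoisted):
    assert torch.allclose(a, b_, atol=1e-5)
print("per-iteration and hoisted forms agree")

The same trade applies to the intra-chunk attention: precomputing attn for every chunk up front spends some extra memory on a (..., n_chunks, chunk_size, chunk_size) tensor, but it removes a masked_fill_ and a matmul from each loop iteration, which is exactly the kind of per-chunk work this commit trims for prefill.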
