Commit 61c0b12

triton kernel fusion for EAGLE

Signed-off-by: Leo Tian <[email protected]>

1 parent e3f3aee · commit 61c0b12

File tree: 1 file changed (+115 −38 lines changed)

vllm/v1/spec_decode/eagle.py

Lines changed: 115 additions & 38 deletions
@@ -166,52 +166,18 @@ def propose(
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
-            # Update the inputs.
-            # cast to int32 is crucial when eagle model is compiled.
-            # tensor.argmax() returns int64 by default.
-            input_ids = draft_token_ids_list[-1].int()
-            positions += 1
-
-            # NOTE(woosuk): We should handle the case where the draft model
-            # generates tokens beyond the max model length. Since it is complex
-            # to remove such requests from the batch, we keep them in the batch
-            # but adjust the position ids and slot mappings to avoid the
-            # out-of-range access during the model execution. The draft tokens
-            # generated with this adjustment should be ignored.
-            exceeds_max_model_len = positions >= self.max_model_len
-            # Mask out the position ids that exceed the max model length.
-            # Otherwise, we may get out-of-range error in RoPE.
-            clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                            positions)
+
+            self.advance_speculative_state(draft_token_ids_list[-1], positions,
+                                           hidden_states, attn_metadata,
+                                           batch_size)

             # Increment the sequence lengths.
             attn_metadata.max_seq_len += 1
-            attn_metadata.seq_lens += 1
             # Consider max model length.
             attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                             self.max_model_len)
-            # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
-            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
-
-            # Compute the slot mapping.
-            block_numbers = clamped_positions // self.block_size
-            block_ids = block_table.gather(dim=1,
-                                           index=block_numbers.view(-1, 1))
-            block_ids = block_ids.view(-1)
-            attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          clamped_positions % self.block_size)
-            # Mask out the slot mappings that exceed the max model length.
-            # Otherwise, the KV cache will be inadvertently updated with the
-            # padding tokens.
-            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
-                                                    PADDING_SLOT_ID)

             # copy inputs to buffer for cudagraph
-            self.input_ids[:batch_size] = input_ids
-            self.positions[:batch_size] = clamped_positions
-            self.hidden_states[:batch_size] = hidden_states
-
             # Run the model.
             with set_forward_context(attn_metadata,
                                      self.vllm_config,
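
The loop body above now performs the whole per-step state update in one fused kernel launch instead of the chain of small elementwise tensor ops that was removed. For comparison, the removed eager path condensed into a standalone helper; an illustrative sketch only (names are made up, not code from this commit):

import torch

def eager_advance_state(positions, seq_lens, block_table, block_size,
                        max_model_len, padding_slot_id):
    # Mirrors the removed per-step update: advance positions, clamp anything
    # past the max model length, bump sequence lengths, and recompute the
    # paged-KV slot mapping.
    positions = positions + 1
    exceeds = positions >= max_model_len
    # Clamp so RoPE never sees an out-of-range position id.
    clamped = torch.where(exceeds, torch.zeros_like(positions), positions)
    # Requests past the max length keep seq_len 1 to stay cheap in attention.
    seq_lens = (seq_lens + 1).masked_fill(exceeds, 1)
    block_numbers = clamped // block_size
    block_ids = block_table.gather(dim=1,
                                   index=block_numbers.view(-1, 1)).view(-1)
    slot_mapping = block_ids * block_size + clamped % block_size
    # Redirect out-of-range requests to the padding slot so their tokens do
    # not pollute the KV cache.
    slot_mapping = slot_mapping.masked_fill(exceeds, padding_slot_id)
    return positions, clamped, seq_lens, slot_mapping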
@@ -233,6 +199,38 @@ def propose(
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
         return draft_token_ids

+    def advance_speculative_state(self, draft_token_ids: torch.Tensor,
+                                  positions: torch.Tensor,
+                                  hidden_states: torch.Tensor,
+                                  attn_metadata: FlashAttentionMetadata,
+                                  batch_size: int):
+        grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), )
+        attn_metadata.slot_mapping = torch.empty_like(positions)
+        advance_state_kernel[grid](
+            # === Input tensors ===
+            draft_token_ids,
+            positions,
+            hidden_states,
+
+            # === Model input buffers to be updated ===
+            self.input_ids[:batch_size],
+            self.positions[:batch_size],
+            self.hidden_states[:batch_size],
+
+            # === Metadata tensors ===
+            attn_metadata.seq_lens,
+            attn_metadata.block_table,
+            attn_metadata.slot_mapping,
+
+            # === Scalar configuration ===
+            self.max_model_len,
+            self.block_size,
+
+            # === Execution control ===
+            batch_size,
+            BLOCK_SIZE=1024,
+            PADDING_SLOT_ID=PADDING_SLOT_ID)
+
     @staticmethod
     def prepare_inputs(
         # [batch_size + 1]
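
A minimal sketch of the launch-grid arithmetic used by advance_speculative_state above (illustrative only): with BLOCK_SIZE=1024, triton.cdiv covers the batch by ceil division, so a typical decode batch maps to a single program instance that updates every request element-wise.

import triton

batch_size = 256
grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), )
print(grid({'BLOCK_SIZE': 1024}))  # (1,) -> one program covers the whole batch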
@@ -415,3 +413,82 @@ def prepare_input_kernel(
         index_start + offset,
         mask=offset < num_tokens,
     )
+
+
+@triton.jit
+def advance_state_kernel(
+    draft_token_ids_ptr,
+    positions_ptr,
+    hidden_states_ptr,
+
+    # === Model input buffers to be updated ===
+    model_input_ids_ptr,
+    model_positions_ptr,
+    model_hidden_states_ptr,
+
+    # === Metadata tensors ===
+    seq_lens_ptr,
+    block_table_ptr,
+    slot_mapping_ptr,
+
+    # === Scalar configuration ===
+    model_max_len: int,
+    model_block_size: int,
+
+    # === Execution control ===
+    n_elements: int,
+    BLOCK_SIZE: tl.constexpr,
+    PADDING_SLOT_ID: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    draft_token_list_last = tl.load(draft_token_ids_ptr + offsets, mask=mask)
+    position = tl.load(positions_ptr + offsets, mask=mask)
+    seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
+    hidden_states = tl.load(hidden_states_ptr + offsets, mask=mask)
+
+    # Update the inputs.
+    # cast to int32 is crucial when eagle model is compiled.
+    # tensor.argmax() returns int64 by default.
+    input_id = draft_token_list_last.cast(tl.int32)
+    position = position + 1
+
+    # NOTE(woosuk): We should handle the case where the draft model
+    # generates tokens beyond the max model length. Since it is complex
+    # to remove such requests from the batch, we keep them in the batch
+    # but adjust the position ids and slot mappings to avoid the
+    # out-of-range access during the model execution. The draft tokens
+    # generated with this adjustment should be ignored.
+    exceeds_max_model_len = position >= model_max_len
+    # Mask out the position ids that exceed the max model length.
+    # Otherwise, we may get out-of-range error in RoPE.
+    clamped_position = tl.where(exceeds_max_model_len, 0, position)
+
+    # For the requests that exceed the max model length, we set the
+    # sequence length to 1 to minimize their overheads in attention.
+    seq_lens += 1
+    seq_lens = tl.where(exceeds_max_model_len, 1, seq_lens)
+
+    block_numbers = clamped_position // model_block_size
+    block_offsets = clamped_position % model_block_size
+
+    # Gather from block_table[0, block_numbers]
+    block_ids = tl.load(block_table_ptr + block_numbers, mask=mask)
+
+    # Compute slot mapping
+    slot_mapping = block_ids * model_block_size + block_offsets
+
+    # Mask out the slot mappings that exceed the max model length.
+    # Otherwise, the KV cache will be inadvertently updated with the
+    # padding tokens.
+    slot_mapping = tl.where(exceeds_max_model_len, PADDING_SLOT_ID,
+                            slot_mapping)
+
+    tl.store(model_input_ids_ptr + offsets, input_id, mask=mask)
+    tl.store(positions_ptr + offsets, position, mask=mask)
+    tl.store(model_positions_ptr + offsets, clamped_position, mask=mask)
+    tl.store(seq_lens_ptr + offsets, seq_lens, mask=mask)
+    tl.store(slot_mapping_ptr + offsets, slot_mapping, mask=mask)
+    tl.store(model_hidden_states_ptr + offsets, hidden_states, mask=mask)
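
A rough smoke test for the fused kernel, not part of this commit. It assumes a CUDA device with this commit checked out, uses toy simplifications (a flat block table and one hidden value per request, matching the kernel's element-wise indexing), and a placeholder PADDING_SLOT_ID:

import torch
import triton
from vllm.v1.spec_decode.eagle import advance_state_kernel

PADDING_SLOT_ID = -1  # placeholder value for the example only
batch, max_len, blk = 4, 2048, 16

draft_ids = torch.randint(0, 32000, (batch, ), dtype=torch.int64, device='cuda')
positions = torch.tensor([5, 100, 2046, 2047], dtype=torch.int64, device='cuda')
hidden = torch.randn(batch, device='cuda')  # toy: one value per request
in_ids = torch.empty(batch, dtype=torch.int32, device='cuda')
model_pos = torch.empty(batch, dtype=torch.int64, device='cuda')
model_hidden = torch.empty(batch, device='cuda')
seq_lens = (positions + 1).to(torch.int32)
block_table = torch.arange(128, dtype=torch.int32, device='cuda')  # toy: flat
slot_mapping = torch.empty(batch, dtype=torch.int64, device='cuda')

grid = lambda meta: (triton.cdiv(batch, meta['BLOCK_SIZE']), )
advance_state_kernel[grid](
    draft_ids, positions, hidden,
    in_ids, model_pos, model_hidden,
    seq_lens, block_table, slot_mapping,
    max_len, blk,
    batch, BLOCK_SIZE=1024, PADDING_SLOT_ID=PADDING_SLOT_ID)

print(positions)     # advanced in place: [6, 101, 2047, 2048]
print(model_pos)     # clamped copy fed to the model: [6, 101, 2047, 0]
print(seq_lens)      # bumped, reset to 1 past max_len: [7, 102, 2048, 1]
print(slot_mapping)  # padding slot where position >= max_len: [6, 101, 2047, -1]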
