@@ -166,52 +166,18 @@ def propose(
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
-            # Update the inputs.
-            # cast to int32 is crucial when eagle model is compiled.
-            # tensor.argmax() returns int64 by default.
-            input_ids = draft_token_ids_list[-1].int()
-            positions += 1
-
-            # NOTE(woosuk): We should handle the case where the draft model
-            # generates tokens beyond the max model length. Since it is complex
-            # to remove such requests from the batch, we keep them in the batch
-            # but adjust the position ids and slot mappings to avoid the
-            # out-of-range access during the model execution. The draft tokens
-            # generated with this adjustment should be ignored.
-            exceeds_max_model_len = positions >= self.max_model_len
-            # Mask out the position ids that exceed the max model length.
-            # Otherwise, we may get out-of-range error in RoPE.
-            clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                            positions)
+
+            self.advance_speculative_state(draft_token_ids_list[-1], positions,
+                                           hidden_states, attn_metadata,
+                                           batch_size)
 
             # Increment the sequence lengths.
             attn_metadata.max_seq_len += 1
-            attn_metadata.seq_lens += 1
             # Consider max model length.
             attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                             self.max_model_len)
-            # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
-            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
-
-            # Compute the slot mapping.
-            block_numbers = clamped_positions // self.block_size
-            block_ids = block_table.gather(dim=1,
-                                           index=block_numbers.view(-1, 1))
-            block_ids = block_ids.view(-1)
-            attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          clamped_positions % self.block_size)
-            # Mask out the slot mappings that exceed the max model length.
-            # Otherwise, the KV cache will be inadvertently updated with the
-            # padding tokens.
-            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
-                                                    PADDING_SLOT_ID)
 
             # copy inputs to buffer for cudagraph
-            self.input_ids[:batch_size] = input_ids
-            self.positions[:batch_size] = clamped_positions
-            self.hidden_states[:batch_size] = hidden_states
-
             # Run the model.
             with set_forward_context(attn_metadata,
                                      self.vllm_config,
@@ -233,6 +199,38 @@ def propose(
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
         return draft_token_ids
 
+    def advance_speculative_state(self, draft_token_ids: torch.Tensor,
+                                  positions: torch.Tensor,
+                                  hidden_states: torch.Tensor,
+                                  attn_metadata: FlashAttentionMetadata,
+                                  batch_size: int):
+        grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), )
+        attn_metadata.slot_mapping = torch.empty_like(positions)
+        advance_state_kernel[grid](
+            # === Input tensors ===
+            draft_token_ids,
+            positions,
+            hidden_states,
+
+            # === Model input buffers to be updated ===
+            self.input_ids[:batch_size],
+            self.positions[:batch_size],
+            self.hidden_states[:batch_size],
+
+            # === Metadata tensors ===
+            attn_metadata.seq_lens,
+            attn_metadata.block_table,
+            attn_metadata.slot_mapping,
+
+            # === Scalar configuration ===
+            self.max_model_len,
+            self.block_size,
+
+            # === Execution control ===
+            batch_size,
+            BLOCK_SIZE=1024,
+            PADDING_SLOT_ID=PADDING_SLOT_ID)
+
     @staticmethod
     def prepare_inputs(
         # [batch_size + 1]
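
The new advance_speculative_state wrapper launches advance_state_kernel over a one-dimensional grid of triton.cdiv(batch_size, BLOCK_SIZE) programs, so a single fused kernel replaces the run of small elementwise torch ops that previously executed on every speculative step; only the scalar max_seq_len bookkeeping stays in Python. Below is a minimal sketch of the grid arithmetic, assuming only Triton's triton.cdiv helper; the num_programs function is illustrative and not part of vLLM.

# Illustrative sketch of the launch-grid math used by advance_speculative_state.
# Each Triton program covers up to BLOCK_SIZE requests; the in-kernel check
# `mask = offsets < n_elements` keeps the final (partial) program from reading
# or writing past the end of the batch.
import triton

def num_programs(batch_size: int, block_size: int = 1024) -> int:
    # Ceiling division: how many programs the grid lambda produces.
    return triton.cdiv(batch_size, block_size)

assert num_programs(1) == 1      # small batches still launch one program
assert num_programs(1024) == 1   # exactly one full block
assert num_programs(1025) == 2   # the second program is mostly masked off
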
@@ -415,3 +413,82 @@ def prepare_input_kernel(
         index_start + offset,
         mask=offset < num_tokens,
     )
+
+
+@triton.jit
+def advance_state_kernel(
+    draft_token_ids_ptr,
+    positions_ptr,
+    hidden_states_ptr,
+
+    # === Model input buffers to be updated ===
+    model_input_ids_ptr,
+    model_positions_ptr,
+    model_hidden_states_ptr,
+
+    # === Metadata tensors ===
+    seq_lens_ptr,
+    block_table_ptr,
+    slot_mapping_ptr,
+
+    # === Scalar configuration ===
+    model_max_len: int,
+    model_block_size: int,
+
+    # === Execution control ===
+    n_elements: int,
+    BLOCK_SIZE: tl.constexpr,
+    PADDING_SLOT_ID: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    draft_token_list_last = tl.load(draft_token_ids_ptr + offsets, mask=mask)
+    position = tl.load(positions_ptr + offsets, mask=mask)
+    seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
+    hidden_states = tl.load(hidden_states_ptr + offsets, mask=mask)
+
+    # Update the inputs.
+    # cast to int32 is crucial when eagle model is compiled.
+    # tensor.argmax() returns int64 by default.
+    input_id = draft_token_list_last.cast(tl.int32)
+    position = position + 1
+
+    # NOTE(woosuk): We should handle the case where the draft model
+    # generates tokens beyond the max model length. Since it is complex
+    # to remove such requests from the batch, we keep them in the batch
+    # but adjust the position ids and slot mappings to avoid the
+    # out-of-range access during the model execution. The draft tokens
+    # generated with this adjustment should be ignored.
+    exceeds_max_model_len = position >= model_max_len
+    # Mask out the position ids that exceed the max model length.
+    # Otherwise, we may get out-of-range error in RoPE.
+    clamped_position = tl.where(exceeds_max_model_len, 0, position)
+
+    # For the requests that exceed the max model length, we set the
+    # sequence length to 1 to minimize their overheads in attention.
+    seq_lens += 1
+    seq_lens = tl.where(exceeds_max_model_len, 1, seq_lens)
+
+    block_numbers = clamped_position // model_block_size
+    block_offsets = clamped_position % model_block_size
+
+    # Gather from block_table[0, block_numbers]
+    block_ids = tl.load(block_table_ptr + block_numbers, mask=mask)
+
+    # Compute slot mapping
+    slot_mapping = block_ids * model_block_size + block_offsets
+
+    # Mask out the slot mappings that exceed the max model length.
+    # Otherwise, the KV cache will be inadvertently updated with the
+    # padding tokens.
+    slot_mapping = tl.where(exceeds_max_model_len, PADDING_SLOT_ID,
+                            slot_mapping)
+
+    tl.store(model_input_ids_ptr + offsets, input_id, mask=mask)
+    tl.store(positions_ptr + offsets, position, mask=mask)
+    tl.store(model_positions_ptr + offsets, clamped_position, mask=mask)
+    tl.store(seq_lens_ptr + offsets, seq_lens, mask=mask)
+    tl.store(slot_mapping_ptr + offsets, slot_mapping, mask=mask)
+    tl.store(model_hidden_states_ptr + offsets, hidden_states, mask=mask)
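
For reference, the slot-mapping arithmetic in the kernel is the usual paged-KV-cache indexing: the clamped position selects a logical block, the block table maps it to a physical block id, and the slot is that block's base offset plus the position within the block; requests that have reached max_model_len get PADDING_SLOT_ID so their throwaway draft tokens never touch the KV cache. A worked PyTorch sketch follows; the block table, block size, max length, and the -1 sentinel are made-up illustrative values, not taken from this diff.

# Worked example (illustrative, not vLLM code) of the per-request slot mapping
# computed by advance_state_kernel, mirroring the eager torch code this change removes.
import torch

BLOCK_SIZE = 16          # tokens per KV-cache block (illustrative)
MAX_MODEL_LEN = 64       # illustrative max model length
PADDING_SLOT_ID = -1     # assumed sentinel meaning "do not write to the KV cache"

positions = torch.tensor([3, 35, 70])       # next-token position per request
block_table = torch.tensor([[7, 2, 9, 4],   # physical block ids per request
                            [5, 8, 1, 3],
                            [6, 0, 2, 9]])

exceeds = positions >= MAX_MODEL_LEN
clamped = torch.where(exceeds, 0, positions)
block_numbers = clamped // BLOCK_SIZE       # e.g. 35 // 16 == 2
block_ids = block_table.gather(1, block_numbers.view(-1, 1)).view(-1)
slots = block_ids * BLOCK_SIZE + clamped % BLOCK_SIZE
slots = torch.where(exceeds, PADDING_SLOT_ID, slots)
print(slots)   # tensor([115,  19,  -1]): 7*16+3, 1*16+3, and a padded request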