 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             LogprobsTensors, ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -156,6 +156,53 @@ def graph_capture(device: torch.device):
     yield graph_capture_context


+# Wrapper for ModelRunnerOutput to support overlapped execution.
+class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        sampled_token_ids: torch.Tensor,
+        invalid_req_indices: list[int],
+        async_output_copy_stream: torch.npu.Stream,
+    ):
+        self._model_runner_output = model_runner_output
+        self._invalid_req_indices = invalid_req_indices
+
+        # Event on the copy stream so we can synchronize the non-blocking copy.
+        self._async_copy_ready_event = torch.npu.Event()
+
+        # Keep a reference to the device tensor to avoid it being
+        # deallocated until we finish copying it to the host.
+        self._sampled_token_ids = sampled_token_ids
+
+        # Initiate the copy on a separate stream, but do not synchronize it.
+        default_stream = torch.npu.current_stream()
+        with torch.npu.stream(async_output_copy_stream):
+            async_output_copy_stream.wait_stream(default_stream)
+            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+                'cpu', non_blocking=True)
+            self._async_copy_ready_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        """Copy the device tensors to the host and return a ModelRunnerOutput.
+
+        This function blocks until the copy is finished.
+        """
+        self._async_copy_ready_event.synchronize()
+
+        # Release the device tensor once the copy has completed.
+        del self._sampled_token_ids
+
+        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
+        for i in self._invalid_req_indices:
+            valid_sampled_token_ids[i].clear()
+
+        output = self._model_runner_output
+        output.sampled_token_ids = valid_sampled_token_ids
+        return output
+
+
 class NPUModelRunner(LoRAModelRunnerMixin):

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
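For context: the wrapper above hands the device-to-host copy of the sampled tokens to a dedicated side stream and records an event on it, so `execute_model` can return immediately while `get_output()` blocks only on that event. Below is a minimal sketch of the same pattern written against the CUDA stream API, which `torch_npu` mirrors as `torch.npu`; it is illustrative only and not part of the patch.

```python
import torch


def async_d2h_copy(device_tensor: torch.Tensor,
                   copy_stream: torch.cuda.Stream):
    """Start a non-blocking device->host copy on a side stream.

    Returns the host tensor and an event that completes when the copy does.
    """
    ready = torch.cuda.Event()
    # Order the copy after all work already queued on the default stream.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        host_tensor = device_tensor.to("cpu", non_blocking=True)
        ready.record()  # records on copy_stream inside this context
    return host_tensor, ready


if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    ids = torch.randint(0, 100, (8, 1), device="cuda")
    host, ready = async_d2h_copy(ids, side_stream)
    # ... overlap other host-side work here ...
    ready.synchronize()  # block only when the result is actually needed
    print(host.tolist())
```

The `wait_stream` call is what makes the deferred copy safe: it guarantees the copy observes the fully computed tensor without stalling the default stream.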
@@ -358,6 +405,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             device=self.device,
         )

+        self.use_async_scheduling = self.scheduler_config.async_scheduling
+        self.async_output_copy_stream = torch.npu.Stream() if \
+            self.use_async_scheduling else None
+
     def _use_aclgraph(self) -> bool:
         return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager

@@ -845,6 +896,76 @@ def _get_cumsum_and_arange(

         return cu_num_tokens, arange

+    def _prepare_input_ids(self, total_num_scheduled_tokens: int,
+                           cu_num_tokens: np.ndarray) -> None:
+        """Prepare the input IDs for the current batch.
+
+        Carefully handles the `prev_sampled_token_ids` which can be cached
+        from the previous engine iteration, in which case those tokens on the
+        NPU need to be copied into the corresponding slots of input_ids."""
+
+        if self.input_batch.prev_sampled_token_ids is None:
+            # Normal scheduling case
+            self.input_ids[:total_num_scheduled_tokens].copy_(
+                self.input_ids_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
+            return
+
+        # Async scheduling case, where some decode requests from the previous
+        # iteration won't have entries in input_ids_cpu and need to be copied
+        # on the NPU from prev_sampled_token_ids.
+        prev_req_id_to_index = self.input_batch.prev_req_id_to_index
+        assert prev_req_id_to_index is not None
+        flattened_indices = []
+        prev_common_req_indices = []
+        indices_match = True
+        max_flattened_index = -1
+        for req_id, cur_index in self.input_batch.req_id_to_index.items():
+            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
+                prev_common_req_indices.append(prev_index)
+                # We need to compute the flattened input_ids index of the
+                # last token in each common request.
+                flattened_index = cu_num_tokens[cur_index].item() - 1
+                flattened_indices.append(flattened_index)
+                indices_match &= (prev_index == flattened_index)
+                max_flattened_index = max(max_flattened_index, flattened_index)
+        num_common_tokens = len(flattened_indices)
+        if num_common_tokens < total_num_scheduled_tokens:
+            # If not all requests are decodes from the last iteration,
+            # we need to copy the input_ids_cpu to the NPU first.
+            self.input_ids[:total_num_scheduled_tokens].copy_(
+                self.input_ids_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
+        if num_common_tokens == 0:
+            # No requests in common with the previous iteration,
+            # so input_ids_cpu will have all the input ids.
+            return
+        if indices_match and max_flattened_index == (num_common_tokens - 1):
+            # Common-case optimization: the batch is unchanged
+            # and no reordering happened.
+            # The indices are both the same permutation of 0..N-1 so
+            # we can copy directly using a single slice.
+            self.input_ids[:num_common_tokens].copy_(
+                self.input_batch.prev_sampled_token_ids[:num_common_tokens,
+                                                        0],
+                non_blocking=True)
+            return
+        # Upload the index tensors asynchronously
+        # so the scatter can be non-blocking.
+        input_ids_index_tensor = torch.tensor(flattened_indices,
+                                              dtype=torch.int64,
+                                              pin_memory=self.pin_memory).to(
+                                                  self.device,
+                                                  non_blocking=True)
+        prev_common_req_indices_tensor = torch.tensor(
+            prev_common_req_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
+        self.input_ids.scatter_(dim=0,
+                                index=input_ids_index_tensor,
+                                src=self.input_batch.prev_sampled_token_ids[
+                                    prev_common_req_indices_tensor, 0])
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
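For context: the scatter at the end of `_prepare_input_ids` places each surviving request's previously sampled token into the flattened slot of its last scheduled token. A CPU-only sketch with made-up shapes and values (not part of the patch):

```python
import torch

input_ids = torch.zeros(6, dtype=torch.long)        # flattened token buffer
prev_sampled = torch.tensor([[11], [22], [33]])     # one token per prev request
cu_num_tokens = torch.tensor([2, 4, 6])             # cumulative tokens per request

flattened_indices = cu_num_tokens - 1               # last-token slot per request
prev_common_req_indices = torch.tensor([0, 1, 2])   # rows in prev_sampled

# Place each cached token at its request's last flattened position.
input_ids.scatter_(dim=0,
                   index=flattened_indices,
                   src=prev_sampled[prev_common_req_indices, 0])
print(input_ids.tolist())  # [0, 11, 0, 22, 0, 33]
```

When the batch is an unchanged pure-decode batch, `flattened_indices` degenerates to `0..N-1`, which is exactly the fast path that replaces the scatter with a single slice copy.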
@@ -1033,6 +1154,16 @@ def _prepare_inputs(
         if self.vllm_config.model_config.use_mla:
             attn_metadata.num_input_tokens = num_input_tokens

+        # Prepare input_ids
+        token_indices = (positions_np +
+                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        # Copy the tensors to the NPU.
+        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
+
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
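For context: `token_ids_cpu` is a `(num_reqs, max_model_len)` table, so the token at row `r`, position `p` lives at flattened index `p + r * max_model_len`; the `index_select` above gathers all scheduled tokens in one shot. A toy sketch (not part of the patch):

```python
import numpy as np
import torch

token_ids_cpu = torch.tensor([[10, 11, 12],
                              [20, 21, 22]])   # (num_reqs, max_model_len)
req_indices = np.array([0, 0, 1])              # owning request per scheduled token
positions_np = np.array([0, 1, 0])             # position within each request

# Flattened index: position + row * row_width.
token_indices = positions_np + req_indices * token_ids_cpu.shape[1]
out = torch.empty(3, dtype=token_ids_cpu.dtype)
torch.index_select(token_ids_cpu.flatten(), 0,
                   torch.from_numpy(token_indices), out=out)
print(out.tolist())  # [10, 11, 20]
```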
@@ -1382,11 +1513,11 @@ def _select_moe_comm_method(self, num_tokens: int) -> str:
        2. If expert parallel is enabled, we need to consider the SoC version and the
        number of tokens. This is based on the observation that all-gather is more
        efficient than all-to-all when running on A2.
-
+
            a. For A2, we choose from MC2 and all-gather.
-
+
            b. For A3, we choose from MC2 and all-to-all.
-
+
        In both cases, we use MC2 when the number of tokens is smaller than
        its capacity threshold.
@@ -1424,7 +1555,7 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, torch.Tensor]:
+    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
             if not scheduler_output.total_num_scheduled_tokens:
@@ -1580,6 +1711,12 @@ def execute_model(
                     generator.set_offset(generator.get_offset() - 4)
                 discard_sampled_tokens_req_indices.append(i)

+        # Copy some objects so they don't get modified after returning.
+        # This is important when using async scheduling.
+        req_ids_output_copy = self.input_batch.req_ids.copy()
+        req_id_to_index_output_copy = \
+            self.input_batch.req_id_to_index.copy()
+
         # NOTE: NPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
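For context: these `.copy()` calls matter because, with async scheduling, the runner may mutate `input_batch` for the next iteration before the caller consumes the returned output. A toy sketch of the aliasing hazard the snapshots avoid (not part of the patch):

```python
req_ids = ["req-0", "req-1"]
output_ids = req_ids         # alias: later mutations leak into the output
snapshot = req_ids.copy()    # snapshot: unaffected by later mutations

req_ids.append("req-2")      # next iteration updates the live batch
print(output_ids)            # ['req-0', 'req-1', 'req-2']  (stale output corrupted)
print(snapshot)              # ['req-0', 'req-1']           (correct)
```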
@@ -1592,27 +1729,52 @@ def execute_model(
                 scheduler_output,
             )

-        # Get the valid generated tokens.
+        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
+        if not self.use_async_scheduling:
+            # Get the valid generated tokens.
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = sampled_token_ids.tolist()
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
+            # Mask out the sampled tokens that should not be sampled.
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
         else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
-            )
-
-        for i in discard_sampled_tokens_req_indices:
-            valid_sampled_token_ids[i].clear()
-        # Cache the sampled tokens in the model runner, so that the scheduler
+            valid_sampled_token_ids = []
+            invalid_req_indices = list(discard_sampled_tokens_req_indices)
+            invalid_req_indices_set = set(invalid_req_indices)
+            assert sampled_token_ids.shape[-1] == 1
+
+            # Cache the sampled tokens on the NPU and avoid CPU sync.
+            # These will be copied into input_ids in the next step
+            # when preparing inputs.
+            self.input_batch.prev_sampled_token_ids = \
+                sampled_token_ids
+            self.input_batch.prev_sampled_token_ids_invalid_indices = \
+                invalid_req_indices_set
+            self.input_batch.prev_req_id_to_index = {
+                req_id: i
+                for i, req_id in enumerate(self.input_batch.req_ids)
+                if i not in invalid_req_indices_set
+            }
+        # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
         # NOTE(woosuk): As an exception, when using PP, the scheduler sends
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
-        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+        for req_idx in range(num_sampled_tokens):
+            if self.use_async_scheduling:
+                sampled_ids = [-1] * 1 if \
+                    req_idx not in invalid_req_indices_set else None
+            else:
+                sampled_ids = valid_sampled_token_ids[req_idx]
             if not sampled_ids:
                 continue
@@ -1650,8 +1812,8 @@ def execute_model(
             extra_args = ({"kv_connector_output": kv_connector_output})

         model_runner_output = ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
+            req_ids=req_ids_output_copy,
+            req_id_to_index=req_id_to_index_output_copy,
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
@@ -1669,7 +1831,15 @@ def execute_model(
             logger.info("Profile execute duration [%s]:%s", captured_name,
                         " ".join(dr_str))

-        return model_runner_output
+        if not self.use_async_scheduling:
+            return model_runner_output
+
+        return AsyncNPUModelRunnerOutput(
+            model_runner_output=model_runner_output,
+            sampled_token_ids=sampled_token_ids,
+            invalid_req_indices=invalid_req_indices,
+            async_output_copy_stream=self.async_output_copy_stream,
+        )

     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
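For context: callers of `execute_model` now receive either a plain `ModelRunnerOutput` or an `AsyncNPUModelRunnerOutput`. A hedged sketch of how a caller might resolve the result; the actual executor/engine wiring in vLLM differs:

```python
def resolve_model_runner_output(output):
    """Return a plain ModelRunnerOutput, blocking on the copy if needed."""
    if isinstance(output, AsyncNPUModelRunnerOutput):
        # The engine can overlap next-step scheduling before this call;
        # get_output() blocks only on the recorded D2H copy event.
        return output.get_output()
    return output
```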