Commit bf6a3d0

[Misc] Add more scoping for improved trace (#28329)

Signed-off-by: Wei Wei <[email protected]>
1 parent 40d3326

4 files changed, +192 −148 lines

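All three diffs in this commit wrap hot sections of the v1 engine loop in record_function_or_nullcontext(...) scopes so that each phase shows up as a named range in a profiler trace. The helper itself comes from vllm/v1/utils.py and is not shown in this excerpt; the sketch below is only an assumption of its shape, wrapping torch.profiler.record_function when scoped tracing is enabled and falling back to a no-op context otherwise (the real gating condition in vLLM may differ).

# Hypothetical sketch of the imported helper; the actual vllm/v1/utils.py
# implementation and its enable/disable condition may differ.
import contextlib
from contextlib import AbstractContextManager

import torch

# Assumed toggle, for illustration only; vLLM gates this on its own env/config.
ENABLE_TRACE_SCOPES = True


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a named profiler scope, or a no-op context when scoping is off."""
    if ENABLE_TRACE_SCOPES:
        # Emits a named range visible in torch.profiler / Chrome traces.
        return torch.profiler.record_function(name)
    return contextlib.nullcontext()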

vllm/v1/core/sched/scheduler.py
Lines changed: 61 additions & 55 deletions

@@ -38,6 +38,7 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import record_function_or_nullcontext

 logger = init_logger(__name__)

@@ -259,49 +260,52 @@ def schedule(self) -> SchedulerOutput:
                 continue

             # Schedule newly needed KV blocks for the request.
-            while True:
-                new_blocks = self.kv_cache_manager.allocate_slots(
-                    request,
-                    num_new_tokens,
-                    num_lookahead_tokens=self.num_lookahead_tokens,
-                )
-
-                if new_blocks is not None:
-                    # The request can be scheduled.
-                    break
-
-                # The request cannot be scheduled.
-                # Preempt the lowest-priority request.
-                if self.policy == SchedulingPolicy.PRIORITY:
-                    preempted_req = max(
-                        self.running,
-                        key=lambda r: (r.priority, r.arrival_time),
+            with record_function_or_nullcontext("schedule: allocate_slots"):
+                while True:
+                    new_blocks = self.kv_cache_manager.allocate_slots(
+                        request,
+                        num_new_tokens,
+                        num_lookahead_tokens=self.num_lookahead_tokens,
                     )
-                    self.running.remove(preempted_req)
-                    if preempted_req in scheduled_running_reqs:
-                        scheduled_running_reqs.remove(preempted_req)
-                        token_budget += num_scheduled_tokens[preempted_req.request_id]
-                        req_to_new_blocks.pop(preempted_req.request_id)
-                        num_scheduled_tokens.pop(preempted_req.request_id)
-                        req_index -= 1
-                else:
-                    preempted_req = self.running.pop()

-                self.kv_cache_manager.free(preempted_req)
-                self.encoder_cache_manager.free(preempted_req)
-                preempted_req.status = RequestStatus.PREEMPTED
-                preempted_req.num_computed_tokens = 0
-                preempted_req.num_preemptions += 1
-                if self.log_stats:
-                    preempted_req.record_event(
-                        EngineCoreEventType.PREEMPTED, scheduled_timestamp
-                    )
+                    if new_blocks is not None:
+                        # The request can be scheduled.
+                        break

-                self.waiting.prepend_request(preempted_req)
-                preempted_reqs.append(preempted_req)
-                if preempted_req == request:
-                    # No more request to preempt. Cannot schedule this request.
-                    break
+                    # The request cannot be scheduled.
+                    # Preempt the lowest-priority request.
+                    if self.policy == SchedulingPolicy.PRIORITY:
+                        preempted_req = max(
+                            self.running,
+                            key=lambda r: (r.priority, r.arrival_time),
+                        )
+                        self.running.remove(preempted_req)
+                        if preempted_req in scheduled_running_reqs:
+                            scheduled_running_reqs.remove(preempted_req)
+                            token_budget += num_scheduled_tokens[
+                                preempted_req.request_id
+                            ]
+                            req_to_new_blocks.pop(preempted_req.request_id)
+                            num_scheduled_tokens.pop(preempted_req.request_id)
+                            req_index -= 1
+                    else:
+                        preempted_req = self.running.pop()
+
+                    self.kv_cache_manager.free(preempted_req)
+                    self.encoder_cache_manager.free(preempted_req)
+                    preempted_req.status = RequestStatus.PREEMPTED
+                    preempted_req.num_computed_tokens = 0
+                    preempted_req.num_preemptions += 1
+                    if self.log_stats:
+                        preempted_req.record_event(
+                            EngineCoreEventType.PREEMPTED, scheduled_timestamp
+                        )
+
+                    self.waiting.prepend_request(preempted_req)
+                    preempted_reqs.append(preempted_req)
+                    if preempted_req == request:
+                        # No more request to preempt. Cannot schedule this request.
+                        break

             if new_blocks is None:
                 # Cannot schedule this request.

@@ -599,13 +603,14 @@ def schedule(self) -> SchedulerOutput:
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
-        if self.running:
-            any_request = self.running[0]
-            num_common_prefix_blocks = (
-                self.kv_cache_manager.get_num_common_prefix_blocks(
-                    any_request.request_id
+        with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
+            if self.running:
+                any_request = self.running[0]
+                num_common_prefix_blocks = (
+                    self.kv_cache_manager.get_num_common_prefix_blocks(
+                        any_request.request_id
+                    )
                 )
-            )

         # Construct the scheduler output.
         new_reqs_data = [

@@ -614,13 +619,14 @@ def schedule(self) -> SchedulerOutput:
             )
             for req in scheduled_new_reqs
         ]
-        cached_reqs_data = self._make_cached_request_data(
-            scheduled_running_reqs,
-            scheduled_resumed_reqs,
-            num_scheduled_tokens,
-            scheduled_spec_decode_tokens,
-            req_to_new_blocks,
-        )
+        with record_function_or_nullcontext("schedule: make_cached_request_data"):
+            cached_reqs_data = self._make_cached_request_data(
+                scheduled_running_reqs,
+                scheduled_resumed_reqs,
+                num_scheduled_tokens,
+                scheduled_spec_decode_tokens,
+                req_to_new_blocks,
+            )

         # Record the request ids that were scheduled in this step.
         self.prev_step_scheduled_req_ids.clear()

@@ -649,8 +655,8 @@ def schedule(self) -> SchedulerOutput:
         if self.connector is not None:
             meta = self.connector.build_connector_meta(scheduler_output)
             scheduler_output.kv_connector_metadata = meta
-
-        self._update_after_schedule(scheduler_output)
+        with record_function_or_nullcontext("schedule: update_after_schedule"):
+            self._update_after_schedule(scheduler_output)
         return scheduler_output

     def _update_after_schedule(

vllm/v1/engine/core.py
Lines changed: 73 additions & 44 deletions

@@ -61,6 +61,7 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import record_function_or_nullcontext
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)

@@ -315,17 +316,21 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
             return {}, False
-        scheduler_output = self.scheduler.schedule()
-        future = self.model_executor.execute_model(scheduler_output, non_block=True)
-        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
-        with self.log_error_detail(scheduler_output):
-            model_output = future.result()
-        if model_output is None:
-            model_output = self.model_executor.sample_tokens(grammar_output)
-
-        engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, model_output
-        )
+        with record_function_or_nullcontext("core step: schedule"):
+            scheduler_output = self.scheduler.schedule()
+
+        with record_function_or_nullcontext("core step: execute_model"):
+            future = self.model_executor.execute_model(scheduler_output, non_block=True)
+            grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+            with self.log_error_detail(scheduler_output):
+                model_output = future.result()
+            if model_output is None:
+                model_output = self.model_executor.sample_tokens(grammar_output)
+
+        with record_function_or_nullcontext("core step: update_from_output"):
+            engine_core_outputs = self.scheduler.update_from_output(
+                scheduler_output, model_output
+            )

         return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

@@ -363,32 +368,49 @@ def step_with_batch_queue(
         model_executed = False
         deferred_scheduler_output = None
         if self.scheduler.has_requests():
-            scheduler_output = self.scheduler.schedule()
-            exec_future = self.model_executor.execute_model(
-                scheduler_output, non_block=True
-            )
+            with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
+                scheduler_output = self.scheduler.schedule()
+            with record_function_or_nullcontext(
+                "core step_with_batch_queue: execute_model"
+            ):
+                exec_future = self.model_executor.execute_model(
+                    scheduler_output, non_block=True
+                )
             model_executed = scheduler_output.total_num_scheduled_tokens > 0

             if scheduler_output.pending_structured_output_tokens:
-                # We need to defer sampling until we have processed the model output
-                # from the prior step.
-                deferred_scheduler_output = scheduler_output
-                # Block-wait for execute to return (continues running async on the GPU).
-                with self.log_error_detail(scheduler_output):
-                    exec_result = exec_future.result()
-                assert exec_result is None
+                with record_function_or_nullcontext(
+                    "core step_with_batch_queue: pending_structured_output_tokens"
+                ):
+                    # We need to defer sampling until we have processed the model output
+                    # from the prior step.
+                    deferred_scheduler_output = scheduler_output
+                    # Block-wait for execute to return
+                    # (continues running async on the GPU).
+                    with self.log_error_detail(scheduler_output):
+                        exec_result = exec_future.result()
+                    assert exec_result is None
             else:
-                # We aren't waiting for any tokens, get any grammar output immediately.
-                grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+                with record_function_or_nullcontext(
+                    "core step_with_batch_queue: get_grammar_bitmask"
+                ):
+                    # We aren't waiting for any tokens, get any grammar
+                    # output immediately.
+                    grammar_output = self.scheduler.get_grammar_bitmask(
+                        scheduler_output
+                    )
                 # Block-wait for execute to return (continues running async on the GPU).
                 with self.log_error_detail(scheduler_output):
                     exec_result = exec_future.result()

                 if exec_result is None:
-                    # Call sample tokens.
-                    future = self.model_executor.sample_tokens(
-                        grammar_output, non_block=True
-                    )
+                    with record_function_or_nullcontext(
+                        "core step_with_batch_queue: sample_tokens"
+                    ):
+                        # Call sample tokens.
+                        future = self.model_executor.sample_tokens(
+                            grammar_output, non_block=True
+                        )
                 else:
                     # No sampling required (e.g. all requests finished).
                     future = cast(Future[ModelRunnerOutput], exec_future)

@@ -408,27 +430,34 @@
             # only be called when the scheduler contains requests or the queue
             # is non-empty.
             return None, False
-
-        # Block until the next result is available.
-        future, scheduler_output = batch_queue.pop()
-        with self.log_error_detail(scheduler_output):
-            model_output = future.result()
-
-        engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, model_output
-        )
+        with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
+            # Block until the next result is available.
+            future, scheduler_output = batch_queue.pop()
+            with self.log_error_detail(scheduler_output):
+                model_output = future.result()
+        with record_function_or_nullcontext(
+            "core step_with_batch_queue: update_from_output"
+        ):
+            engine_core_outputs = self.scheduler.update_from_output(
+                scheduler_output, model_output
+            )

         # NOTE(nick): We can either handle the deferred tasks here or save
         # in a field and do it immediately once step_with_batch_queue is
         # re-called. The latter slightly favors TTFT over TPOT/throughput.
         if deferred_scheduler_output:
-            # We now have the tokens needed to compute the bitmask for the
-            # deferred request. Get the bitmask and call sample tokens.
-            grammar_output = self.scheduler.get_grammar_bitmask(
-                deferred_scheduler_output
-            )
-            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
-            batch_queue.appendleft((future, deferred_scheduler_output))
+            with record_function_or_nullcontext(
+                "core step_with_batch_queue: deferred_scheduler_output"
+            ):
+                # We now have the tokens needed to compute the bitmask for the
+                # deferred request. Get the bitmask and call sample tokens.
+                grammar_output = self.scheduler.get_grammar_bitmask(
+                    deferred_scheduler_output
+                )
+                future = self.model_executor.sample_tokens(
+                    grammar_output, non_block=True
+                )
+                batch_queue.appendleft((future, deferred_scheduler_output))

         return engine_core_outputs, model_executed
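To see how these named scopes surface, a trace can be captured around the engine loop with torch.profiler. The snippet below is illustrative only: it profiles a stand-in function rather than a real EngineCore, but ranges such as "core step: schedule" and the nested "schedule: allocate_slots" appear the same way in an actual trace of the code above.

# Illustrative capture only: shows scope names like those added in this commit
# appearing (and nesting) in a torch.profiler trace. No real engine is involved.
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def fake_engine_step() -> None:
    # Stand-in for EngineCore.step(); in vLLM the scopes come from
    # record_function_or_nullcontext in the diffs above.
    with record_function("core step: schedule"):
        with record_function("schedule: allocate_slots"):
            pass  # scheduling work would run here
    with record_function("core step: execute_model"):
        _ = torch.ones(512, 512) @ torch.ones(512, 512)  # placeholder compute


with profile(activities=[ProfilerActivity.CPU]) as prof:
    fake_engine_step()

# The named scopes show up as rows here and as ranges in an exported
# Chrome trace (prof.export_chrome_trace("trace.json")).
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))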

vllm/v1/engine/llm_engine.py
Lines changed: 21 additions & 16 deletions

@@ -36,6 +36,7 @@
 from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
 from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
 from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.utils import record_function_or_nullcontext
 from vllm.v1.worker.worker_base import WorkerBase

 logger = init_logger(__name__)

@@ -280,28 +281,32 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]:
             return []

         # 1) Get EngineCoreOutput from the EngineCore.
-        outputs = self.engine_core.get_output()
+        with record_function_or_nullcontext("llm_genine step: get_output"):
+            outputs = self.engine_core.get_output()

         # 2) Process EngineCoreOutputs.
-        iteration_stats = IterationStats() if self.log_stats else None
-        processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs,
-            engine_core_timestamp=outputs.timestamp,
-            iteration_stats=iteration_stats,
-        )
-        self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
+        with record_function_or_nullcontext("llm_genine step: process_outputs"):
+            iteration_stats = IterationStats() if self.log_stats else None
+            processed_outputs = self.output_processor.process_outputs(
+                outputs.outputs,
+                engine_core_timestamp=outputs.timestamp,
+                iteration_stats=iteration_stats,
+            )
+            self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

         # 3) Abort any reqs that finished due to stop strings.
-        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
+        with record_function_or_nullcontext("llm_genine step: abort_requests"):
+            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

         # 4) Record stats
-        if self.logger_manager is not None and outputs.scheduler_stats is not None:
-            self.logger_manager.record(
-                scheduler_stats=outputs.scheduler_stats,
-                iteration_stats=iteration_stats,
-                mm_cache_stats=self.processor.stat_mm_cache(),
-            )
-            self.do_log_stats_with_interval()
+        with record_function_or_nullcontext("llm_genine step: record_stats"):
+            if self.logger_manager is not None and outputs.scheduler_stats is not None:
+                self.logger_manager.record(
+                    scheduler_stats=outputs.scheduler_stats,
+                    iteration_stats=iteration_stats,
+                    mm_cache_stats=self.processor.stat_mm_cache(),
+                )
+                self.do_log_stats_with_interval()

         return processed_outputs.request_outputs
