diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md
index f6a73e99546e..1c06f0371599 100644
--- a/docs/contributing/profiling.md
+++ b/docs/contributing/profiling.md
@@ -231,3 +231,40 @@ Leverage VLLM_GC_DEBUG environment variable to debug GC costs.
 - VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elpased times
 - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger to log top 5
   collected objects for each gc.collect
+
+## Lite Profiler
+
+The lite profiler is a lightweight, minimal-overhead profiling tool designed to capture the time spent in the critical components of the system.
+
+### How It Works
+
+The lite profiler uses context managers to time code blocks and writes the timing data to a log file with minimal overhead. After execution completes, it aggregates the logged entries into a summary report showing the time spent in each execution stage. Currently, it is enabled for both the `vllm bench throughput` and `vllm bench latency` commands.
+
+### Environment Variable
+
+Enable lite profiling by setting the `VLLM_LITE_PROFILER_LOG_PATH` environment variable to the desired log file path:
+
+```bash
+export VLLM_LITE_PROFILER_LOG_PATH=/tmp/vllm_lite_profile.log
+```
+
+### Example Commands and Usage
+
+#### Throughput Benchmark
+
+Profile throughput with the lite profiler:
+
+```bash
+VLLM_LITE_PROFILER_LOG_PATH=/tmp/vllm_lite_profile.log \
+vllm bench throughput \
+    --model meta-llama/Llama-3.1-8B-Instruct \
+    --dataset-name random \
+    --num-prompts 1000 \
+    --input-len 128 \
+    --output-len 128
+```
+
+At the end of the run, the profiler automatically generates a summary report showing time breakdowns for:
+
+- **Top-level pipeline events**: schedule, model execution, output processing
+- **Model events**: state updates, input preparation, forward pass, post-processing, bookkeeping
\ No newline at end of file
diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py
index 7692697fe768..9de8d4fce87e 100644
--- a/vllm/benchmarks/latency.py
+++ b/vllm/benchmarks/latency.py
@@ -17,6 +17,7 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.inputs import PromptType
 from vllm.sampling_params import BeamSearchParams
+from vllm.utils.lite_profiler.lite_profiler import maybe_emit_lite_profiler_report
 
 
 def save_to_pytorch_benchmark_format(
@@ -170,3 +171,6 @@ def run_to_completion(profile_dir: Optional[str] = None):
         with open(args.output_json, "w") as f:
             json.dump(results, f, indent=4)
         save_to_pytorch_benchmark_format(args, results)
+
+    # Generate the lite-profiler report if enabled.
+    maybe_emit_lite_profiler_report()
diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py
index b0f63fd2c722..158318fde99b 100644
--- a/vllm/benchmarks/throughput.py
+++ b/vllm/benchmarks/throughput.py
@@ -35,6 +35,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import merge_async_iterators
+from vllm.utils.lite_profiler.lite_profiler import maybe_emit_lite_profiler_report
 
 
 def run_vllm(
@@ -789,3 +790,6 @@ def main(args: argparse.Namespace):
         with open(args.output_json, "w") as f:
             json.dump(results, f, indent=4)
         save_to_pytorch_benchmark_format(args, results)
+
+    # Generate the lite-profiler report if enabled.
+ maybe_emit_lite_profiler_report() diff --git a/vllm/envs.py b/vllm/envs.py index ab8548cf5066..aac6e5e45e6a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -209,6 +209,7 @@ VLLM_NCCL_INCLUDE_PATH: Optional[str] = None VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" + VLLM_LITE_PROFILER_LOG_PATH: Optional[str] = None def get_default_cache_root(): @@ -1391,6 +1392,10 @@ def get_vllm_port() -> Optional[int]: # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # top 5 collected objects "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), + # Log path for the lightweight timing profiler. + # If this path is set (not None), lightweight profiling will be enabled, + # providing detailed execution latency breakdown by segment. + "VLLM_LITE_PROFILER_LOG_PATH": lambda: os.getenv("VLLM_LITE_PROFILER_LOG_PATH"), } # --8<-- [end:env-vars-definition] diff --git a/vllm/utils/lite_profiler/__init__.py b/vllm/utils/lite_profiler/__init__.py new file mode 100644 index 000000000000..42c3b5085a99 --- /dev/null +++ b/vllm/utils/lite_profiler/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Lightweight profiler for timing code execution with minimal overhead.""" + +from vllm.utils.lite_profiler.lite_profiler import ( + LiteProfilerScope, + maybe_emit_lite_profiler_report, +) + +__all__ = [ + "LiteProfilerScope", + "maybe_emit_lite_profiler_report", +] diff --git a/vllm/utils/lite_profiler/lite_profiler.py b/vllm/utils/lite_profiler/lite_profiler.py new file mode 100644 index 000000000000..6a89498b0530 --- /dev/null +++ b/vllm/utils/lite_profiler/lite_profiler.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Minimal helpers for opt-in lightweight timing collection.""" + +from __future__ import annotations + +import atexit +import functools +import multiprocessing +import os +import time +from types import TracebackType +from typing import TextIO + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@functools.cache +def _should_log_results() -> bool: + """Check if the current process should log results. + Only the data-parallel rank 0 engine core and worker 0 should emit logs in + multi-process deployments so that we avoid duplicating identical timing + data. + """ + process = multiprocessing.current_process() + return process.name in ("EngineCore_DP0", "VllmWorker-0") + + +# Cache for log file handle +_log_file: TextIO | None = None + + +def _write_log_entry(name: str, elapsed_ns: int) -> None: + """Write a profiler entry using cached file handle for optimal performance. + + This function implements an efficient caching approach where the file handle + is opened once and reused for all subsequent writes. This eliminates the + significant overhead of opening/closing files for every profiler entry, + which is crucial for maintaining the lightweight nature of the profiler. + + The cached file handle is automatically closed on program exit via atexit. + """ + global _log_file + + log_path = envs.VLLM_LITE_PROFILER_LOG_PATH + assert log_path is not None + + if not _should_log_results(): + return + + # Handle case where file handle was opened in parent but we're in the + # child process. 
The file descriptor may become invalid after fork + if _log_file is not None: + try: + # Verify if the file handle is still valid + _log_file.flush() + _log_file.tell() + except (OSError, ValueError): + # File handle is stale, clear and reopen + _log_file = None + + # Write the log entry + if _log_file is None: + directory = os.path.dirname(log_path) + if directory: + os.makedirs(directory, exist_ok=True) + # ruff: noqa: SIM115 - intentionally keeping file handle cached globally + # Currently, we are flushing the file handle after every write. This + # is done to ensure safety so that there is no data loss. + # TODO: We can optimise this further, to ensure performance overhead + # can be reduced in future. + _log_file = open(log_path, "a", buffering=1) + atexit.register(_log_file.close) + + _log_file.write(f"{name}|{elapsed_ns}\n") + + +class LiteProfilerScope: + """Lightweight context manager for timing code blocks with minimal overhead. + + This class provides a simple way to measure and log the execution time of + code blocks using Python's context manager protocol (with statement). It's + designed for high-frequency profiling with minimal performance impact. + """ + + def __init__(self, name: str) -> None: + self._name = name + self._start_time_ns: int | None = None + + def __enter__(self) -> None: + self._start_time_ns = time.perf_counter_ns() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if self._start_time_ns is not None: + elapsed_ns = time.perf_counter_ns() - self._start_time_ns + _write_log_entry(self._name, elapsed_ns) + + +def maybe_emit_lite_profiler_report() -> None: + """Generate and display a summary report of profiling data if available. + + This function serves as the main entry point for analyzing and displaying + profiling results. It checks if profiling was enabled and a log file exists, + then delegates to the lite_profiler_report module to generate statistics + like function call counts, timing distributions, and performance insights. + """ + + log_path = envs.VLLM_LITE_PROFILER_LOG_PATH + if log_path is None or not os.path.exists(log_path): + return + + # Ensure the log file is flushed and closed before generating report + global _log_file + if _log_file is not None: + _log_file.close() + _log_file = None + + from vllm.utils.lite_profiler import lite_profiler_report + + logger.info("") + logger.info("Lite profiler summary (%s):", log_path) + try: + # Generate and display the summary report + lite_profiler_report.summarize_log(log_path) + os.remove(log_path) + except Exception as exc: # pragma: no cover - avoid crashing benchmarks + logger.error("Failed to summarize lite profiler log %s: %s", log_path, exc) diff --git a/vllm/utils/lite_profiler/lite_profiler_report.py b/vllm/utils/lite_profiler/lite_profiler_report.py new file mode 100644 index 000000000000..6a12624c13db --- /dev/null +++ b/vllm/utils/lite_profiler/lite_profiler_report.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Summarize a single vLLM lite-profiler log in tabular form. 
+
+The script consumes the pipe-separated records emitted by `vllm.utils.lite_profiler.lite_profiler`.
+It expects log lines in the format: "<scope_name>|<elapsed_ns>".
+"""
+
+from __future__ import annotations
+
+from collections import defaultdict
+from collections.abc import Iterable, Sequence
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def _extract_event_ns(filename: str) -> dict[str, list[int]]:
+    """Collect the nanosecond timings for every scope in ``filename``."""
+
+    all_event_ns: dict[str, list[int]] = defaultdict(list)
+    try:
+        with open(filename, encoding="utf-8") as f:
+            for raw_line in f:
+                line = raw_line.strip()
+                if not line:
+                    continue
+
+                # Parse the format: "scope_name|elapsed_nanoseconds"
+                if "|" in line:
+                    try:
+                        scope_name, elapsed_ns_str = line.split("|", 1)
+                        elapsed_ns = int(elapsed_ns_str)
+                        all_event_ns[scope_name].append(elapsed_ns)
+                    except (ValueError, IndexError):
+                        # Skip malformed lines
+                        continue
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Lite-profiler log not found: {filename}") from None
+    return all_event_ns
+
+
+def _sum_events(event_ns: dict[str, list[int]]) -> dict[str, int]:
+    return {event: sum(values) for event, values in event_ns.items()}
+
+
+def _format_duration_ns(value_ns: int, total_ns: int) -> str:
+    seconds = value_ns / 1e9 if value_ns else 0.0
+    percent = (value_ns * 100.0 / total_ns) if total_ns else 0.0
+    return f"{seconds:.3f}s ({percent:.2f}%)"
+
+
+def _render_table(
+    title: str, headers: Sequence[str], rows: Iterable[Sequence[str]]
+) -> None:
+    table = [list(headers)] + [list(row) for row in rows]
+    widths = [max(len(row[i]) for row in table) for i in range(len(headers))]
+
+    logger.info("")
+    logger.info(title)
+    separator = "-" * (sum(widths) + 3 * (len(widths) - 1))
+    logger.info(separator)
+
+    def _fmt(row: Sequence[str]) -> str:
+        return " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row))
+
+    logger.info(_fmt(table[0]))
+    logger.info(" | ".join("-" * w for w in widths))
+    for row in table[1:]:
+        logger.info(_fmt(row))
+
+
+TOP_EVENTS = [
+    "Step:Schedule",
+    "Step:Model",
+    "Step:Output",
+]
+
+MODEL_EVENTS = [
+    "Model:UpdateState",
+    "Model:PrepareInput",
+    "Model:Forward",
+    "Model:Postprocess",
+    "Model:Bookkeep",
+]
+
+
+def _compute_table_rows(
+    event_ns_sum: dict[str, int],
+    events: Sequence[str],
+) -> list[str]:
+    total_ns = sum(event_ns_sum.get(event, 0) for event in events)
+    cells = []
+    for event in events:
+        cells.append(_format_duration_ns(event_ns_sum.get(event, 0), total_ns))
+    total_seconds = total_ns / 1e9 if total_ns else 0.0
+    cells.append(f"{total_seconds:.3f}s")
+    return cells
+
+
+def _print_breakdown_tables(event_ns_sum: dict[str, int]) -> None:
+    for title, events in (
+        ("Top-level pipeline events", TOP_EVENTS),
+        ("Model events breakdown (only includes the key events)", MODEL_EVENTS),
+    ):
+        headers = [*events, "TOTAL"]
+        rows = [_compute_table_rows(event_ns_sum, events)]
+        _render_table(title, headers, rows)
+
+
+def summarize_log(log_path: str) -> None:
+    elapsed_ns = _extract_event_ns(log_path)
+    event_ns_sum = _sum_events(elapsed_ns)
+    _print_breakdown_tables(event_ns_sum)
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index e6474d91ffed..c054cb91ac12 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -65,6 +65,7 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import
record_function_or_nullcontext from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -317,14 +318,19 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): return {}, False - scheduler_output = self.scheduler.schedule() - model_output = self.execute_model_with_error_logging( - self.model_executor.execute_model, # type: ignore - scheduler_output, - ) - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + with record_function_or_nullcontext("Step:Schedule"): + scheduler_output = self.scheduler.schedule() + + with record_function_or_nullcontext("Step:Model"): + model_output = self.execute_model_with_error_logging( + self.model_executor.execute_model, # type: ignore + scheduler_output, + ) + + with record_function_or_nullcontext("Step:Output"): + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) # type: ignore return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 925943262894..3c21935590a0 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -32,6 +32,7 @@ get_tcp_uri, kill_process_tree, ) +from vllm.utils.lite_profiler.lite_profiler import LiteProfilerScope if TYPE_CHECKING: import numpy as np @@ -380,7 +381,9 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager: return _PROFILER_FUNC(name) func = contextlib.nullcontext - if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: + if envs.VLLM_LITE_PROFILER_LOG_PATH: + func = LiteProfilerScope + elif envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: func = record_function elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING: import nvtx diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ec824f6d6bf5..09a3abbe03fb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2413,10 +2413,11 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: - with record_function_or_nullcontext("Preprocess"): + with record_function_or_nullcontext("Model:Preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. - self._update_states(scheduler_output) + with record_function_or_nullcontext("Model:UpdateState"): + self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: if not has_kv_transfer_group(): @@ -2433,17 +2434,18 @@ def execute_model( ) # Prepare the decoder inputs. 
- ( - attn_metadata, - logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np, - spec_decode_common_attn_metadata, - max_query_len, - ubatch_slices, - num_tokens_across_dp, - use_cascade_attn, - ) = self._prepare_inputs(scheduler_output) + with record_function_or_nullcontext("Model:PrepareInput"): + ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens_np, + spec_decode_common_attn_metadata, + max_query_len, + ubatch_slices, + num_tokens_across_dp, + use_cascade_attn, + ) = self._prepare_inputs(scheduler_output) dp_rank = self.parallel_config.data_parallel_rank if ubatch_slices: @@ -2502,7 +2504,7 @@ def execute_model( batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, ), - record_function_or_nullcontext("Forward"), + record_function_or_nullcontext("Model:Forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): model_output = self._model_forward( @@ -2513,7 +2515,7 @@ def execute_model( **model_kwargs, ) - with record_function_or_nullcontext("Postprocess"): + with record_function_or_nullcontext("Model:Postprocess"): if self.use_aux_hidden_state_outputs: # True when EAGLE 3 is used. hidden_states, aux_hidden_states = model_output @@ -2619,7 +2621,7 @@ def propose_draft_token_ids(sampled_token_ids): # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - with record_function_or_nullcontext("Bookkeep"): + with record_function_or_nullcontext("Model:Bookkeep"): ( num_nans_in_logits, logprobs_lists,
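Reviewer note (illustrative, not part of the diff): the sketch below uses only the Python standard library and a hypothetical `/tmp` path to mimic the end-to-end data flow this change introduces, namely the `name|elapsed_ns` records that `LiteProfilerScope`/`_write_log_entry` append to the log and the per-event aggregation that `summarize_log` performs afterwards. Event names and sleep durations are stand-ins, not real measurements.

```python
# Minimal sketch of the lite-profiler log format and report aggregation.
# Assumptions: LOG_PATH and the event names below are hypothetical examples.
import time
from collections import defaultdict

LOG_PATH = "/tmp/lite_profile_example.log"  # stand-in for VLLM_LITE_PROFILER_LOG_PATH

# Emulate LiteProfilerScope: time a block and append one "name|elapsed_ns" line per scope.
with open(LOG_PATH, "a", buffering=1) as log_file:
    for name, seconds in (("Step:Schedule", 0.001), ("Step:Model", 0.01), ("Step:Output", 0.002)):
        start_ns = time.perf_counter_ns()
        time.sleep(seconds)  # stand-in for the profiled code block
        elapsed_ns = time.perf_counter_ns() - start_ns
        log_file.write(f"{name}|{elapsed_ns}\n")

# Emulate summarize_log: sum the elapsed nanoseconds per event name and print percentages.
event_ns: dict[str, int] = defaultdict(int)
with open(LOG_PATH, encoding="utf-8") as f:
    for raw_line in f:
        line = raw_line.strip()
        if "|" in line:
            name, _, elapsed = line.partition("|")
            event_ns[name] += int(elapsed)

total_ns = sum(event_ns.values())
for name, ns in event_ns.items():
    print(f"{name}: {ns / 1e9:.3f}s ({ns * 100.0 / total_ns:.2f}%)")
```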