From 99967027ab5fc0cf302e53b7f9673c8175743cb4 Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Sat, 11 Oct 2025 13:46:11 -0700 Subject: [PATCH 1/8] Lite profiler implementation Signed-off-by: Naman Lalit --- vllm/benchmarks/latency.py | 4 + vllm/benchmarks/serve.py | 4 + vllm/benchmarks/throughput.py | 4 + vllm/envs.py | 5 + vllm/utils/lite_profiler.py | 144 +++++++++++++++++++++++++++++ vllm/utils/lite_profiler_report.py | 121 ++++++++++++++++++++++++ vllm/v1/engine/core.py | 21 +++-- vllm/v1/utils.py | 5 +- vllm/v1/worker/gpu_model_runner.py | 40 ++++---- 9 files changed, 320 insertions(+), 28 deletions(-) create mode 100644 vllm/utils/lite_profiler.py create mode 100644 vllm/utils/lite_profiler_report.py diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py index b4f1751837f..2ed24fa045d 100644 --- a/vllm/benchmarks/latency.py +++ b/vllm/benchmarks/latency.py @@ -17,6 +17,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.sampling_params import BeamSearchParams +from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report def save_to_pytorch_benchmark_format( @@ -170,3 +171,6 @@ def run_to_completion(profile_dir: str | None = None): with open(args.output_json, "w") as f: json.dump(results, f, indent=4) save_to_pytorch_benchmark_format(args, results) + + # Generate the lite-profiler report if enabled. + maybe_emit_lite_profiler_report() diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 71d136d61ce..e79025daccb 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -48,6 +48,7 @@ from vllm.benchmarks.lib.ready_checker import wait_for_endpoint from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -1511,4 +1512,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) + # Generate the lite-profiler report if enabled. + maybe_emit_lite_profiler_report() + return result_json diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 866365ac18e..ee6b769020a 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -35,6 +35,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.utils.async_utils import merge_async_iterators +from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report def run_vllm( @@ -789,3 +790,6 @@ def main(args: argparse.Namespace): with open(args.output_json, "w") as f: json.dump(results, f, indent=4) save_to_pytorch_benchmark_format(args, results) + + # Generate the lite-profiler report if enabled. + maybe_emit_lite_profiler_report() diff --git a/vllm/envs.py b/vllm/envs.py index e91d8d03321..a72bb9ae34c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -214,6 +214,7 @@ VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False + VLLM_LITE_PROFILER_LOG_PATH: str | None = None def get_default_cache_root(): @@ -1384,6 +1385,10 @@ def get_vllm_port() -> int | None: "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv( "VLLM_DISABLE_SHARED_EXPERTS_STREAM", False ), + # Log path for the lightweight timing profiler. + # If this path is set (not None), lightweight profiling will be enabled, + # providing detailed analysis of the execution time for each function call. + "VLLM_LITE_PROFILER_LOG_PATH": lambda: os.getenv("VLLM_LITE_PROFILER_LOG_PATH"), } # --8<-- [end:env-vars-definition] diff --git a/vllm/utils/lite_profiler.py b/vllm/utils/lite_profiler.py new file mode 100644 index 00000000000..d83b6a3e57c --- /dev/null +++ b/vllm/utils/lite_profiler.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Minimal helpers for opt-in lightweight timing collection.""" + +from __future__ import annotations + +import atexit +import multiprocessing +import os +import time +from contextlib import suppress +from types import TracebackType +from typing import TextIO + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _should_log_results() -> bool: + """Check if the current process should log results. + Only the data-parallel rank 0 engine core and worker 0 should emit logs in + multi-process deployments so that we avoid duplicating identical timing + data. + """ + process = multiprocessing.current_process() + return process.name in ("EngineCore_DP0", "VllmWorker-0") + + +# Cache for log file handle +_log_file: TextIO | None = None +log_results = _should_log_results() + + +def _write_log_entry(name: str, elapsed_us: int) -> None: + """Write a profiler entry using cached file handle for optimal performance. + + This function implements an efficient caching approach where the file handle + is opened once and reused for all subsequent writes. This eliminates the + significant overhead of opening/closing files for every profiler entry, + which is crucial for maintaining the lightweight nature of the profiler. + + The cached file handle is automatically closed on program exit via atexit. + """ + global _log_file + _LOG_PATH = envs.VLLM_LITE_PROFILER_LOG_PATH + + if not log_results or _LOG_PATH is None: + return + + # Handle case where file handle was opened in parent but we're in the + # child process. The file descriptor may become invalid after fork + if _log_file is not None: + try: + # Verify if the file handle is still valid + _log_file.tell() + except (OSError, ValueError): + # File handle is stale, clear and reopen + _log_file = None + + # Write the log entry + log_line = f"{name}|{elapsed_us}\n" + if _log_file is None: + directory = os.path.dirname(_LOG_PATH) + if directory: + os.makedirs(directory, exist_ok=True) + # ruff: noqa: SIM115 - intentionally keeping file handle cached globally + _log_file = open(_LOG_PATH, "a", buffering=50000) + atexit.register(_log_file.close) + + _log_file.write(log_line) + + +class LiteScope: + """Lightweight context manager for timing code blocks with minimal overhead. + + This class provides a simple way to measure and log the execution time of + code blocks using Python's context manager protocol (with statement). It's + designed for high-frequency profiling with minimal performance impact. + """ + + def __init__(self, name: str) -> None: + self._name = name + self._start_time: int | None = None + + def __enter__(self) -> None: + self._start_time = time.perf_counter_ns() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool: + if self._start_time is not None and exc_type is None: + elapsed_ns = time.perf_counter_ns() - self._start_time + # Use integer microseconds for better performance + elapsed_us = elapsed_ns // 1000 + _write_log_entry(self._name, elapsed_us) + return False + + +def maybe_emit_lite_profiler_report() -> None: + """Generate and display a summary report of profiling data if available. + + This function serves as the main entry point for analyzing and displaying + profiling results. It checks if profiling was enabled and a log file exists, + then delegates to the lite_profiler_report module to generate statistics + like function call counts, timing distributions, and performance insights. + """ + + log_path = envs.VLLM_LITE_PROFILER_LOG_PATH + if log_path is None: + return + + if not os.path.exists(log_path): + logger.warning( + "Lite profiler log not found. Ensure the profiled process sets " + "the expected path." + ) + return + + try: + from vllm.utils import lite_profiler_report + except Exception as exc: # pragma: no cover - import error should not crash + logger.error("Failed to import lite profiler report helper: %s", exc) + return + + logger.info("") + logger.info("Lite profiler summary (%s):", log_path) + try: + # Generate and display the summary report + lite_profiler_report.summarize_log(log_path) + + # Clear the log file to avoid accumulating data from multiple runs + with suppress(OSError): + directory = os.path.dirname(log_path) + if directory: + os.makedirs(directory, exist_ok=True) + with open(log_path, "w"): + pass + except Exception as exc: # pragma: no cover - avoid crashing benchmarks + logger.error("Failed to summarize lite profiler log %s: %s", log_path, exc) diff --git a/vllm/utils/lite_profiler_report.py b/vllm/utils/lite_profiler_report.py new file mode 100644 index 00000000000..0d23db5aaf0 --- /dev/null +++ b/vllm/utils/lite_profiler_report.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Summarize a single vLLM lite-profiler log in tabular form. + +The script consumes the pipe-separated records emitted by `vllm.lite_profiler` +It expects log lines in the format: "|" +""" + +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable, Sequence + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _extract_event_us(filename: str) -> dict[str, list[int]]: + """Collect the microsecond timings for every scope in ``filenames``.""" + + all_event_us: dict[str, list[int]] = defaultdict(list) + try: + with open(filename, encoding="utf-8") as f: + for raw_line in f: + line = raw_line.strip() + if not line: + continue + + # Parse the format: "scope_name|elapsed_microseconds" + if "|" in line: + try: + scope_name, elapsed_us_str = line.split("|", 1) + elapsed_us = int(elapsed_us_str) + all_event_us[scope_name].append(elapsed_us) + except (ValueError, IndexError): + # Skip malformed lines + continue + except FileNotFoundError: + raise FileNotFoundError(f"Lite-profiler log not found: {filename}") from None + return all_event_us + + +def _sum_events(event_us: dict[str, list[int]]) -> dict[str, int]: + return {event: sum(values) for event, values in event_us.items()} + + +def _format_duration_us(value_us: int, total_us: int) -> str: + # Convert microseconds to seconds + seconds = value_us / 1e6 if value_us else 0.0 + percent = (value_us * 100.0 / total_us) if total_us else 0.0 + return f"{seconds:.2f}s ({percent:.2f}%)" + + +def _render_table( + title: str, headers: Sequence[str], rows: Iterable[Sequence[str]] +) -> None: + table = [list(headers)] + [list(row) for row in rows] + widths = [max(len(row[i]) for row in table) for i in range(len(headers))] + + logger.info("") + logger.info(title) + separator = "-" * (sum(widths) + 3 * (len(widths) - 1)) + logger.info(separator) + + def _fmt(row: Sequence[str]) -> str: + return " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row)) + + logger.info(_fmt(table[0])) + logger.info(" | ".join("-" * w for w in widths)) + for row in table[1:]: + logger.info(_fmt(row)) + + +TOP_EVENTS = [ + "Input:Process", + "Step:Schedule", + "Step:Model", + "Step:Output", +] + +MODEL_EVENTS = [ + "Model:UpdateState", + "Model:PrepareInput", + "Model:Forward", + "Model:Postprocess", + "Model:Sample", + "Model:Bookkeep", + "Model:EPLB", + "Model:Draft", +] + + +def _compute_table_rows( + event_us_sum: dict[str, int], + events: Sequence[str], +) -> list[str]: + total_us = sum(event_us_sum.get(event, 0) for event in events) + cells = [] + for event in events: + cells.append(_format_duration_us(event_us_sum.get(event, 0), total_us)) + # Convert microseconds to seconds + total_seconds = total_us / 1_000_000 if total_us else 0.0 + cells.append(f"{total_seconds:.2f}s") + return cells + + +def _print_breakdown_tables(event_us_sum: dict[str, int]) -> None: + for title, events in ( + ("Top-level pipeline events", TOP_EVENTS), + ("Model events breakdown (only includes the main key events)", MODEL_EVENTS), + ): + headers = [*events, "TOTAL"] + rows = [_compute_table_rows(event_us_sum, events)] + _render_table(title, headers, rows) + + +def summarize_log(log_path: str) -> None: + event_us = _extract_event_us(log_path) + event_us_sum = _sum_events(event_us) + _print_breakdown_tables(event_us_sum) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 27cf2fbe8c3..9bdd38562c8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -67,6 +67,7 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -315,14 +316,16 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): return {}, False - scheduler_output = self.scheduler.schedule() - - with self.log_error_detail(scheduler_output): + with record_function_or_nullcontext("Step:Schedule"): + scheduler_output = self.scheduler.schedule() + + with self.log_error_detail(scheduler_output), record_function_or_nullcontext("Step:Model"): model_output = self.model_executor.execute_model(scheduler_output) - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + with record_function_or_nullcontext("Step:Output"): + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) # type: ignore return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 @@ -814,7 +817,8 @@ def _process_input_queue(self): logger.debug("EngineCore waiting for work.") waited = True req = self.input_queue.get() - self._handle_client_request(*req) + with record_function_or_nullcontext("Input:Process"): + self._handle_client_request(*req) if waited: logger.debug("EngineCore loop active.") @@ -822,7 +826,8 @@ def _process_input_queue(self): # Handle any more client requests. while not self.input_queue.empty(): req = self.input_queue.get_nowait() - self._handle_client_request(*req) + with record_function_or_nullcontext("Input:Process"): + self._handle_client_request(*req) def _process_engine_step(self) -> bool: """Called only when there are unfinished local requests.""" diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index e8fa8126646..49ff8b12d64 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -26,6 +26,7 @@ from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message from vllm.utils import kill_process_tree +from vllm.utils.lite_profiler import LiteScope from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri if TYPE_CHECKING: @@ -387,7 +388,9 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager: return _PROFILER_FUNC(name) func = contextlib.nullcontext - if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: + if envs.VLLM_LITE_PROFILER_LOG_PATH: + func = LiteScope + elif envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: func = record_function elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING: import nvtx diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b2d99a0ec69..8f589d479ba 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2423,10 +2423,11 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: IntermediateTensors | None = None, ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: - with record_function_or_nullcontext("Preprocess"): + with record_function_or_nullcontext("Model:Preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. - self._update_states(scheduler_output) + with record_function_or_nullcontext("Model:UpdateState"): + self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: if not has_kv_transfer_group(): @@ -2443,17 +2444,18 @@ def execute_model( ) # Prepare the decoder inputs. - ( - attn_metadata, - logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np, - spec_decode_common_attn_metadata, - max_query_len, - ubatch_slices, - num_tokens_across_dp, - use_cascade_attn, - ) = self._prepare_inputs(scheduler_output) + with record_function_or_nullcontext("Model:PrepareInput"): + ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens_np, + spec_decode_common_attn_metadata, + max_query_len, + ubatch_slices, + num_tokens_across_dp, + use_cascade_attn, + ) = self._prepare_inputs(scheduler_output) dp_rank = self.parallel_config.data_parallel_rank if ubatch_slices: @@ -2514,7 +2516,7 @@ def execute_model( batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, ), - record_function_or_nullcontext("Forward"), + record_function_or_nullcontext("Model:Forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): model_output = self._model_forward( @@ -2525,7 +2527,7 @@ def execute_model( **model_kwargs, ) - with record_function_or_nullcontext("Postprocess"): + with record_function_or_nullcontext("Model:Postprocess"): if self.use_aux_hidden_state_outputs: # True when EAGLE 3 is used. hidden_states, aux_hidden_states = model_output @@ -2586,12 +2588,12 @@ def execute_model( if scheduler_output.structured_output_request_ids: apply_grammar_bitmask(scheduler_output, self.input_batch, logits) - with record_function_or_nullcontext("Sample"): + with record_function_or_nullcontext("Model:Sample"): sampler_output = self._sample(logits, spec_decode_metadata) def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): + with record_function_or_nullcontext("Model:Draft"): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, sampled_token_ids, @@ -2629,7 +2631,7 @@ def propose_draft_token_ids(sampled_token_ids): # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - with record_function_or_nullcontext("Bookkeep"): + with record_function_or_nullcontext("Model:Bookkeep"): ( num_nans_in_logits, logprobs_lists, @@ -2655,7 +2657,7 @@ def propose_draft_token_ids(sampled_token_ids): # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) - with record_function_or_nullcontext("EPLB"): + with record_function_or_nullcontext("Model:EPLB"): self.eplb_step() output = ModelRunnerOutput( From 7534a404c6423e94386edb21cf6f8581b2cd1c58 Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Mon, 13 Oct 2025 09:56:36 -0700 Subject: [PATCH 2/8] Addressed review comments Signed-off-by: Naman Lalit --- vllm/envs.py | 6 ++++-- vllm/utils/lite_profiler.py | 23 ++++------------------- vllm/utils/lite_profiler_report.py | 6 +++--- 3 files changed, 11 insertions(+), 24 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index a72bb9ae34c..cf727f99db7 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1387,8 +1387,10 @@ def get_vllm_port() -> int | None: ), # Log path for the lightweight timing profiler. # If this path is set (not None), lightweight profiling will be enabled, - # providing detailed analysis of the execution time for each function call. - "VLLM_LITE_PROFILER_LOG_PATH": lambda: os.getenv("VLLM_LITE_PROFILER_LOG_PATH"), + # providing detailed execution latency breakdown by segment. + "VLLM_LITE_PROFILER_LOG_PATH": lambda: os.getenv( + "VLLM_LITE_PROFILER_LOG_PATH", None + ), } # --8<-- [end:env-vars-definition] diff --git a/vllm/utils/lite_profiler.py b/vllm/utils/lite_profiler.py index d83b6a3e57c..d5559ff1cf3 100644 --- a/vllm/utils/lite_profiler.py +++ b/vllm/utils/lite_profiler.py @@ -8,7 +8,6 @@ import multiprocessing import os import time -from contextlib import suppress from types import TracebackType from typing import TextIO @@ -30,7 +29,6 @@ def _should_log_results() -> bool: # Cache for log file handle _log_file: TextIO | None = None -log_results = _should_log_results() def _write_log_entry(name: str, elapsed_us: int) -> None: @@ -46,7 +44,7 @@ def _write_log_entry(name: str, elapsed_us: int) -> None: global _log_file _LOG_PATH = envs.VLLM_LITE_PROFILER_LOG_PATH - if not log_results or _LOG_PATH is None: + if not _should_log_results() or _LOG_PATH is None: return # Handle case where file handle was opened in parent but we're in the @@ -66,7 +64,7 @@ def _write_log_entry(name: str, elapsed_us: int) -> None: if directory: os.makedirs(directory, exist_ok=True) # ruff: noqa: SIM115 - intentionally keeping file handle cached globally - _log_file = open(_LOG_PATH, "a", buffering=50000) + _log_file = open(_LOG_PATH, "w", buffering=50000) atexit.register(_log_file.close) _log_file.write(log_line) @@ -92,13 +90,12 @@ def __exit__( exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, - ) -> bool: + ) -> None: if self._start_time is not None and exc_type is None: elapsed_ns = time.perf_counter_ns() - self._start_time # Use integer microseconds for better performance elapsed_us = elapsed_ns // 1000 _write_log_entry(self._name, elapsed_us) - return False def maybe_emit_lite_profiler_report() -> None: @@ -121,24 +118,12 @@ def maybe_emit_lite_profiler_report() -> None: ) return - try: - from vllm.utils import lite_profiler_report - except Exception as exc: # pragma: no cover - import error should not crash - logger.error("Failed to import lite profiler report helper: %s", exc) - return + from vllm.utils import lite_profiler_report logger.info("") logger.info("Lite profiler summary (%s):", log_path) try: # Generate and display the summary report lite_profiler_report.summarize_log(log_path) - - # Clear the log file to avoid accumulating data from multiple runs - with suppress(OSError): - directory = os.path.dirname(log_path) - if directory: - os.makedirs(directory, exist_ok=True) - with open(log_path, "w"): - pass except Exception as exc: # pragma: no cover - avoid crashing benchmarks logger.error("Failed to summarize lite profiler log %s: %s", log_path, exc) diff --git a/vllm/utils/lite_profiler_report.py b/vllm/utils/lite_profiler_report.py index 0d23db5aaf0..2f552fa9c33 100644 --- a/vllm/utils/lite_profiler_report.py +++ b/vllm/utils/lite_profiler_report.py @@ -17,7 +17,7 @@ def _extract_event_us(filename: str) -> dict[str, list[int]]: - """Collect the microsecond timings for every scope in ``filenames``.""" + """Collect the microsecond timings for every scope in ``filename``.""" all_event_us: dict[str, list[int]] = defaultdict(list) try: @@ -49,7 +49,7 @@ def _format_duration_us(value_us: int, total_us: int) -> str: # Convert microseconds to seconds seconds = value_us / 1e6 if value_us else 0.0 percent = (value_us * 100.0 / total_us) if total_us else 0.0 - return f"{seconds:.2f}s ({percent:.2f}%)" + return f"{seconds:.3f}s ({percent:.2f}%)" def _render_table( @@ -101,7 +101,7 @@ def _compute_table_rows( cells.append(_format_duration_us(event_us_sum.get(event, 0), total_us)) # Convert microseconds to seconds total_seconds = total_us / 1_000_000 if total_us else 0.0 - cells.append(f"{total_seconds:.2f}s") + cells.append(f"{total_seconds:.3f}s") return cells From e94d0b18abf4e91d18dcb997ccae7ced74a3b434 Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Tue, 14 Oct 2025 16:07:35 -0700 Subject: [PATCH 3/8] address comments and add documentation Signed-off-by: Naman Lalit --- docs/contributing/profiling.md | 38 +++++++++++++++ vllm/benchmarks/latency.py | 2 +- vllm/benchmarks/serve.py | 4 -- vllm/benchmarks/throughput.py | 2 +- .../{ => lite_profiler}/lite_profiler.py | 24 ++++++---- .../lite_profiler_report.py | 48 ++++++++----------- vllm/v1/engine/core.py | 6 +-- vllm/v1/utils.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 8 ++-- 9 files changed, 84 insertions(+), 50 deletions(-) rename vllm/utils/{ => lite_profiler}/lite_profiler.py (87%) rename vllm/utils/{ => lite_profiler}/lite_profiler_report.py (64%) diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index fed286f4b63..e33197d4a75 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -235,3 +235,41 @@ Leverage VLLM_GC_DEBUG environment variable to debug GC costs. - VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elpased times - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger to log top 5 collected objects for each gc.collect + +## Lite Profiler + +The lite profiler is a lightweight, minimal-overhead profiling tool designed for capturing the time spent by the critical components of the system. + +### How It Works + +The lite profiler uses context managers to time code blocks and writes timing data to a log file with minimal overhead (~50KB buffer). After execution completes, it generates an aggregate summary report showing time spent in different pipeline stages. + +### Environment Variable + +Enable lite profiling by setting the `VLLM_LITE_PROFILER_LOG_PATH` environment variable to the desired log file path: + +```bash +export VLLM_LITE_PROFILER_LOG_PATH=/tmp/vllm_lite_profile.log +``` + +### Example Commands and Usage + +#### Throughput Benchmark + +Profile throughput with the lite profiler: + +```bash +VLLM_LITE_PROFILER_LOG_PATH=/tmp/vllm_lite_profile.log \ +vllm bench throughput \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --dataset-name random \ + --num-prompts 1000 \ + --input-len 128 \ + --output-len 128 +``` + +The profiler will automatically generate a summary report at the end showing time breakdowns for: + +- **Top-level pipeline events**: Schedule, Model execution, Output processing +- **Model events**: State updates, input preparation, forward pass, postprocessing, bookkeeping +``` \ No newline at end of file diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py index 2ed24fa045d..b404d910884 100644 --- a/vllm/benchmarks/latency.py +++ b/vllm/benchmarks/latency.py @@ -17,7 +17,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.sampling_params import BeamSearchParams -from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report +from vllm.utils.lite_profiler.lite_profiler import maybe_emit_lite_profiler_report def save_to_pytorch_benchmark_format( diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index e79025daccb..71d136d61ce 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -48,7 +48,6 @@ from vllm.benchmarks.lib.ready_checker import wait_for_endpoint from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -1512,7 +1511,4 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) - # Generate the lite-profiler report if enabled. - maybe_emit_lite_profiler_report() - return result_json diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index ee6b769020a..07ea80409ec 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -35,7 +35,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.utils.async_utils import merge_async_iterators -from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report +from vllm.utils.lite_profiler.lite_profiler import maybe_emit_lite_profiler_report def run_vllm( diff --git a/vllm/utils/lite_profiler.py b/vllm/utils/lite_profiler/lite_profiler.py similarity index 87% rename from vllm/utils/lite_profiler.py rename to vllm/utils/lite_profiler/lite_profiler.py index d5559ff1cf3..13b1b7d4ae6 100644 --- a/vllm/utils/lite_profiler.py +++ b/vllm/utils/lite_profiler/lite_profiler.py @@ -31,7 +31,7 @@ def _should_log_results() -> bool: _log_file: TextIO | None = None -def _write_log_entry(name: str, elapsed_us: int) -> None: +def _write_log_entry(name: str, elapsed_ns: int) -> None: """Write a profiler entry using cached file handle for optimal performance. This function implements an efficient caching approach where the file handle @@ -43,8 +43,9 @@ def _write_log_entry(name: str, elapsed_us: int) -> None: """ global _log_file _LOG_PATH = envs.VLLM_LITE_PROFILER_LOG_PATH + assert _LOG_PATH is not None - if not _should_log_results() or _LOG_PATH is None: + if not _should_log_results(): return # Handle case where file handle was opened in parent but we're in the @@ -52,19 +53,20 @@ def _write_log_entry(name: str, elapsed_us: int) -> None: if _log_file is not None: try: # Verify if the file handle is still valid + _log_file.flush() _log_file.tell() except (OSError, ValueError): # File handle is stale, clear and reopen _log_file = None # Write the log entry - log_line = f"{name}|{elapsed_us}\n" + log_line = f"{name}|{elapsed_ns}\n" if _log_file is None: directory = os.path.dirname(_LOG_PATH) if directory: os.makedirs(directory, exist_ok=True) # ruff: noqa: SIM115 - intentionally keeping file handle cached globally - _log_file = open(_LOG_PATH, "w", buffering=50000) + _log_file = open(_LOG_PATH, "a", buffering=50000) atexit.register(_log_file.close) _log_file.write(log_line) @@ -93,9 +95,7 @@ def __exit__( ) -> None: if self._start_time is not None and exc_type is None: elapsed_ns = time.perf_counter_ns() - self._start_time - # Use integer microseconds for better performance - elapsed_us = elapsed_ns // 1000 - _write_log_entry(self._name, elapsed_us) + _write_log_entry(self._name, elapsed_ns) def maybe_emit_lite_profiler_report() -> None: @@ -111,6 +111,13 @@ def maybe_emit_lite_profiler_report() -> None: if log_path is None: return + # Ensure the log file is flushed and closed before generating report + global _log_file + if _log_file is not None: + _log_file.flush() + _log_file.close() + _log_file = None + if not os.path.exists(log_path): logger.warning( "Lite profiler log not found. Ensure the profiled process sets " @@ -118,12 +125,13 @@ def maybe_emit_lite_profiler_report() -> None: ) return - from vllm.utils import lite_profiler_report + from vllm.utils.lite_profiler import lite_profiler_report logger.info("") logger.info("Lite profiler summary (%s):", log_path) try: # Generate and display the summary report lite_profiler_report.summarize_log(log_path) + os.remove(log_path) except Exception as exc: # pragma: no cover - avoid crashing benchmarks logger.error("Failed to summarize lite profiler log %s: %s", log_path, exc) diff --git a/vllm/utils/lite_profiler_report.py b/vllm/utils/lite_profiler/lite_profiler_report.py similarity index 64% rename from vllm/utils/lite_profiler_report.py rename to vllm/utils/lite_profiler/lite_profiler_report.py index 2f552fa9c33..6a12624c13d 100644 --- a/vllm/utils/lite_profiler_report.py +++ b/vllm/utils/lite_profiler/lite_profiler_report.py @@ -17,9 +17,9 @@ def _extract_event_us(filename: str) -> dict[str, list[int]]: - """Collect the microsecond timings for every scope in ``filename``.""" + """Collect the nanosecond timings for every scope in ``filename``.""" - all_event_us: dict[str, list[int]] = defaultdict(list) + all_event_ns: dict[str, list[int]] = defaultdict(list) try: with open(filename, encoding="utf-8") as f: for raw_line in f: @@ -27,28 +27,27 @@ def _extract_event_us(filename: str) -> dict[str, list[int]]: if not line: continue - # Parse the format: "scope_name|elapsed_microseconds" + # Parse the format: "scope_name|elapsed_nanoseconds" if "|" in line: try: - scope_name, elapsed_us_str = line.split("|", 1) - elapsed_us = int(elapsed_us_str) - all_event_us[scope_name].append(elapsed_us) + scope_name, elapsed_ns_str = line.split("|", 1) + elapsed_ns = int(elapsed_ns_str) + all_event_ns[scope_name].append(elapsed_ns) except (ValueError, IndexError): # Skip malformed lines continue except FileNotFoundError: raise FileNotFoundError(f"Lite-profiler log not found: {filename}") from None - return all_event_us + return all_event_ns -def _sum_events(event_us: dict[str, list[int]]) -> dict[str, int]: - return {event: sum(values) for event, values in event_us.items()} +def _sum_events(event_ns: dict[str, list[int]]) -> dict[str, int]: + return {event: sum(values) for event, values in event_ns.items()} -def _format_duration_us(value_us: int, total_us: int) -> str: - # Convert microseconds to seconds - seconds = value_us / 1e6 if value_us else 0.0 - percent = (value_us * 100.0 / total_us) if total_us else 0.0 +def _format_duration_ns(value_ns: int, total_ns: int) -> str: + seconds = value_ns / 1e9 if value_ns else 0.0 + percent = (value_ns * 100.0 / total_ns) if total_ns else 0.0 return f"{seconds:.3f}s ({percent:.2f}%)" @@ -73,7 +72,6 @@ def _fmt(row: Sequence[str]) -> str: TOP_EVENTS = [ - "Input:Process", "Step:Schedule", "Step:Model", "Step:Output", @@ -84,38 +82,34 @@ def _fmt(row: Sequence[str]) -> str: "Model:PrepareInput", "Model:Forward", "Model:Postprocess", - "Model:Sample", "Model:Bookkeep", - "Model:EPLB", - "Model:Draft", ] def _compute_table_rows( - event_us_sum: dict[str, int], + event_ns_sum: dict[str, int], events: Sequence[str], ) -> list[str]: - total_us = sum(event_us_sum.get(event, 0) for event in events) + total_ns = sum(event_ns_sum.get(event, 0) for event in events) cells = [] for event in events: - cells.append(_format_duration_us(event_us_sum.get(event, 0), total_us)) - # Convert microseconds to seconds - total_seconds = total_us / 1_000_000 if total_us else 0.0 + cells.append(_format_duration_ns(event_ns_sum.get(event, 0), total_ns)) + total_seconds = total_ns / 1e9 if total_ns else 0.0 cells.append(f"{total_seconds:.3f}s") return cells -def _print_breakdown_tables(event_us_sum: dict[str, int]) -> None: +def _print_breakdown_tables(event_ns_sum: dict[str, int]) -> None: for title, events in ( ("Top-level pipeline events", TOP_EVENTS), ("Model events breakdown (only includes the main key events)", MODEL_EVENTS), ): headers = [*events, "TOTAL"] - rows = [_compute_table_rows(event_us_sum, events)] + rows = [_compute_table_rows(event_ns_sum, events)] _render_table(title, headers, rows) def summarize_log(log_path: str) -> None: - event_us = _extract_event_us(log_path) - event_us_sum = _sum_events(event_us) - _print_breakdown_tables(event_us_sum) + elapsed_ns = _extract_event_us(log_path) + event_ns_sum = _sum_events(elapsed_ns) + _print_breakdown_tables(event_ns_sum) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9bdd38562c8..9cbc0f681dc 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -817,8 +817,7 @@ def _process_input_queue(self): logger.debug("EngineCore waiting for work.") waited = True req = self.input_queue.get() - with record_function_or_nullcontext("Input:Process"): - self._handle_client_request(*req) + self._handle_client_request(*req) if waited: logger.debug("EngineCore loop active.") @@ -826,8 +825,7 @@ def _process_input_queue(self): # Handle any more client requests. while not self.input_queue.empty(): req = self.input_queue.get_nowait() - with record_function_or_nullcontext("Input:Process"): - self._handle_client_request(*req) + self._handle_client_request(*req) def _process_engine_step(self) -> bool: """Called only when there are unfinished local requests.""" diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 49ff8b12d64..00c84ccf3a2 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -26,7 +26,7 @@ from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message from vllm.utils import kill_process_tree -from vllm.utils.lite_profiler import LiteScope +from vllm.utils.lite_profiler.lite_profiler import LiteScope from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri if TYPE_CHECKING: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8f589d479ba..9ae9c75a6e0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2588,12 +2588,12 @@ def execute_model( if scheduler_output.structured_output_request_ids: apply_grammar_bitmask(scheduler_output, self.input_batch, logits) - with record_function_or_nullcontext("Model:Sample"): + with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Model:Draft"): + with record_function_or_nullcontext("Draft"): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, sampled_token_ids, @@ -2631,7 +2631,7 @@ def propose_draft_token_ids(sampled_token_ids): # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - with record_function_or_nullcontext("Model:Bookkeep"): + with record_function_or_nullcontext("Bookkeep"): ( num_nans_in_logits, logprobs_lists, @@ -2657,7 +2657,7 @@ def propose_draft_token_ids(sampled_token_ids): # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) - with record_function_or_nullcontext("Model:EPLB"): + with record_function_or_nullcontext("EPLB"): self.eplb_step() output = ModelRunnerOutput( From b3d86565ef50541db8dee31e7f9524c15d583c30 Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Tue, 14 Oct 2025 16:45:10 -0700 Subject: [PATCH 4/8] fix documentation error Signed-off-by: Naman Lalit --- vllm/utils/lite_profiler/__init__.py | 13 +++++++++++++ vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 vllm/utils/lite_profiler/__init__.py diff --git a/vllm/utils/lite_profiler/__init__.py b/vllm/utils/lite_profiler/__init__.py new file mode 100644 index 00000000000..cdd3687e122 --- /dev/null +++ b/vllm/utils/lite_profiler/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Lightweight profiler for timing code execution with minimal overhead.""" + +from vllm.utils.lite_profiler.lite_profiler import ( + LiteScope, + maybe_emit_lite_profiler_report, +) + +__all__ = [ + "LiteScope", + "maybe_emit_lite_profiler_report", +] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9ae9c75a6e0..e81017039f3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2631,7 +2631,7 @@ def propose_draft_token_ids(sampled_token_ids): # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - with record_function_or_nullcontext("Bookkeep"): + with record_function_or_nullcontext("Model:Bookkeep"): ( num_nans_in_logits, logprobs_lists, From d423cb6a301efdc72b0a3f30fa5b0888b6c55221 Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Wed, 15 Oct 2025 11:37:04 -0700 Subject: [PATCH 5/8] add line buffer to avoid data leakage Signed-off-by: Naman Lalit --- vllm/utils/lite_profiler/lite_profiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/utils/lite_profiler/lite_profiler.py b/vllm/utils/lite_profiler/lite_profiler.py index 13b1b7d4ae6..a6ecda29596 100644 --- a/vllm/utils/lite_profiler/lite_profiler.py +++ b/vllm/utils/lite_profiler/lite_profiler.py @@ -66,7 +66,7 @@ def _write_log_entry(name: str, elapsed_ns: int) -> None: if directory: os.makedirs(directory, exist_ok=True) # ruff: noqa: SIM115 - intentionally keeping file handle cached globally - _log_file = open(_LOG_PATH, "a", buffering=50000) + _log_file = open(_LOG_PATH, "a", buffering=1) atexit.register(_log_file.close) _log_file.write(log_line) @@ -114,7 +114,6 @@ def maybe_emit_lite_profiler_report() -> None: # Ensure the log file is flushed and closed before generating report global _log_file if _log_file is not None: - _log_file.flush() _log_file.close() _log_file = None From 310a4e114413523db344edb00e316944e7805b72 Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Wed, 15 Oct 2025 14:05:19 -0700 Subject: [PATCH 6/8] Update documentation and add comments Signed-off-by: Naman Lalit --- docs/contributing/profiling.md | 2 +- vllm/utils/lite_profiler/lite_profiler.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index e33197d4a75..047165d7073 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -242,7 +242,7 @@ The lite profiler is a lightweight, minimal-overhead profiling tool designed for ### How It Works -The lite profiler uses context managers to time code blocks and writes timing data to a log file with minimal overhead (~50KB buffer). After execution completes, it generates an aggregate summary report showing time spent in different pipeline stages. +The lite profiler uses context managers to time code blocks and writes timing data to a log file with minimal overhead. After execution completes, it uses the logs from the log file to generate an aggregate summary report showing time spent in different pipeline stages. Currently, its enabled for both the `vllm bench throughput` and `vllm bench latency` commands. ### Environment Variable diff --git a/vllm/utils/lite_profiler/lite_profiler.py b/vllm/utils/lite_profiler/lite_profiler.py index a6ecda29596..ee0b889783e 100644 --- a/vllm/utils/lite_profiler/lite_profiler.py +++ b/vllm/utils/lite_profiler/lite_profiler.py @@ -66,6 +66,10 @@ def _write_log_entry(name: str, elapsed_ns: int) -> None: if directory: os.makedirs(directory, exist_ok=True) # ruff: noqa: SIM115 - intentionally keeping file handle cached globally + # Currently, we are flushing the file handle after every write. This + # is done to ensure safety so that there is no data leakage. + # TODO: We can optimise this further, to ensure performance overhead + # can be reduced in future. _log_file = open(_LOG_PATH, "a", buffering=1) atexit.register(_log_file.close) From b2861e3cc2edd45fe01abf21554d95f96b09913b Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Thu, 16 Oct 2025 12:54:51 -0700 Subject: [PATCH 7/8] Address review comments Signed-off-by: Naman Lalit --- docs/contributing/profiling.md | 6 ++-- vllm/envs.py | 4 +-- vllm/utils/lite_profiler/__init__.py | 4 +-- vllm/utils/lite_profiler/lite_profiler.py | 35 ++++++++++------------- vllm/v1/utils.py | 4 +-- 5 files changed, 23 insertions(+), 30 deletions(-) diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 047165d7073..d84d76aca10 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -242,7 +242,7 @@ The lite profiler is a lightweight, minimal-overhead profiling tool designed for ### How It Works -The lite profiler uses context managers to time code blocks and writes timing data to a log file with minimal overhead. After execution completes, it uses the logs from the log file to generate an aggregate summary report showing time spent in different pipeline stages. Currently, its enabled for both the `vllm bench throughput` and `vllm bench latency` commands. +The lite profiler uses context managers to time code blocks and writes timing data to a log file with minimal overhead. After execution completes, it uses the logs from the log file to generate an aggregated summary report showing time spent in different execution stages. Currently, it's enabled for both the `vllm bench throughput` and `vllm bench latency` commands. ### Environment Variable @@ -270,6 +270,6 @@ vllm bench throughput \ The profiler will automatically generate a summary report at the end showing time breakdowns for: -- **Top-level pipeline events**: Schedule, Model execution, Output processing -- **Model events**: State updates, input preparation, forward pass, postprocessing, bookkeeping +- **Top-level pipeline events**: schedule, model execution, output processing +- **Model events**: state updates, input preparation, forward pass, post-processing, bookkeeping ``` \ No newline at end of file diff --git a/vllm/envs.py b/vllm/envs.py index cf727f99db7..804d9bef05a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1388,9 +1388,7 @@ def get_vllm_port() -> int | None: # Log path for the lightweight timing profiler. # If this path is set (not None), lightweight profiling will be enabled, # providing detailed execution latency breakdown by segment. - "VLLM_LITE_PROFILER_LOG_PATH": lambda: os.getenv( - "VLLM_LITE_PROFILER_LOG_PATH", None - ), + "VLLM_LITE_PROFILER_LOG_PATH": lambda: os.getenv("VLLM_LITE_PROFILER_LOG_PATH"), } # --8<-- [end:env-vars-definition] diff --git a/vllm/utils/lite_profiler/__init__.py b/vllm/utils/lite_profiler/__init__.py index cdd3687e122..42c3b5085a9 100644 --- a/vllm/utils/lite_profiler/__init__.py +++ b/vllm/utils/lite_profiler/__init__.py @@ -3,11 +3,11 @@ """Lightweight profiler for timing code execution with minimal overhead.""" from vllm.utils.lite_profiler.lite_profiler import ( - LiteScope, + LiteProfilerScope, maybe_emit_lite_profiler_report, ) __all__ = [ - "LiteScope", + "LiteProfilerScope", "maybe_emit_lite_profiler_report", ] diff --git a/vllm/utils/lite_profiler/lite_profiler.py b/vllm/utils/lite_profiler/lite_profiler.py index ee0b889783e..9c47168855a 100644 --- a/vllm/utils/lite_profiler/lite_profiler.py +++ b/vllm/utils/lite_profiler/lite_profiler.py @@ -8,6 +8,7 @@ import multiprocessing import os import time +from functools import cache from types import TracebackType from typing import TextIO @@ -17,6 +18,7 @@ logger = init_logger(__name__) +@cache def _should_log_results() -> bool: """Check if the current process should log results. Only the data-parallel rank 0 engine core and worker 0 should emit logs in @@ -42,8 +44,9 @@ def _write_log_entry(name: str, elapsed_ns: int) -> None: The cached file handle is automatically closed on program exit via atexit. """ global _log_file - _LOG_PATH = envs.VLLM_LITE_PROFILER_LOG_PATH - assert _LOG_PATH is not None + + log_path = envs.VLLM_LITE_PROFILER_LOG_PATH + assert log_path is not None if not _should_log_results(): return @@ -60,23 +63,22 @@ def _write_log_entry(name: str, elapsed_ns: int) -> None: _log_file = None # Write the log entry - log_line = f"{name}|{elapsed_ns}\n" if _log_file is None: - directory = os.path.dirname(_LOG_PATH) + directory = os.path.dirname(log_path) if directory: os.makedirs(directory, exist_ok=True) # ruff: noqa: SIM115 - intentionally keeping file handle cached globally # Currently, we are flushing the file handle after every write. This - # is done to ensure safety so that there is no data leakage. + # is done to ensure safety so that there is no data loss. # TODO: We can optimise this further, to ensure performance overhead # can be reduced in future. - _log_file = open(_LOG_PATH, "a", buffering=1) + _log_file = open(log_path, "a", buffering=1) atexit.register(_log_file.close) - _log_file.write(log_line) + _log_file.write(f"{name}|{elapsed_ns}\n") -class LiteScope: +class LiteProfilerScope: """Lightweight context manager for timing code blocks with minimal overhead. This class provides a simple way to measure and log the execution time of @@ -86,10 +88,10 @@ class LiteScope: def __init__(self, name: str) -> None: self._name = name - self._start_time: int | None = None + self._start_time_ns: int | None = None def __enter__(self) -> None: - self._start_time = time.perf_counter_ns() + self._start_time_ns = time.perf_counter_ns() def __exit__( self, @@ -97,8 +99,8 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - if self._start_time is not None and exc_type is None: - elapsed_ns = time.perf_counter_ns() - self._start_time + if self._start_time_ns is not None: + elapsed_ns = time.perf_counter_ns() - self._start_time_ns _write_log_entry(self._name, elapsed_ns) @@ -112,7 +114,7 @@ def maybe_emit_lite_profiler_report() -> None: """ log_path = envs.VLLM_LITE_PROFILER_LOG_PATH - if log_path is None: + if log_path is None or not os.path.exists(log_path): return # Ensure the log file is flushed and closed before generating report @@ -121,13 +123,6 @@ def maybe_emit_lite_profiler_report() -> None: _log_file.close() _log_file = None - if not os.path.exists(log_path): - logger.warning( - "Lite profiler log not found. Ensure the profiled process sets " - "the expected path." - ) - return - from vllm.utils.lite_profiler import lite_profiler_report logger.info("") diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 00c84ccf3a2..5879bbecc2a 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -26,7 +26,7 @@ from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message from vllm.utils import kill_process_tree -from vllm.utils.lite_profiler.lite_profiler import LiteScope +from vllm.utils.lite_profiler.lite_profiler import LiteProfilerScope from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri if TYPE_CHECKING: @@ -389,7 +389,7 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager: func = contextlib.nullcontext if envs.VLLM_LITE_PROFILER_LOG_PATH: - func = LiteScope + func = LiteProfilerScope elif envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: func = record_function elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING: From 6b2c0c3d3f9abc591d6d9aaadc3478bd65651bda Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Thu, 16 Oct 2025 12:57:55 -0700 Subject: [PATCH 8/8] use different functools for cache Signed-off-by: Naman Lalit --- vllm/utils/lite_profiler/lite_profiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/utils/lite_profiler/lite_profiler.py b/vllm/utils/lite_profiler/lite_profiler.py index 9c47168855a..6a89498b053 100644 --- a/vllm/utils/lite_profiler/lite_profiler.py +++ b/vllm/utils/lite_profiler/lite_profiler.py @@ -5,10 +5,10 @@ from __future__ import annotations import atexit +import functools import multiprocessing import os import time -from functools import cache from types import TracebackType from typing import TextIO @@ -18,7 +18,7 @@ logger = init_logger(__name__) -@cache +@functools.cache def _should_log_results() -> bool: """Check if the current process should log results. Only the data-parallel rank 0 engine core and worker 0 should emit logs in