4 changes: 4 additions & 0 deletions vllm/benchmarks/latency.py
@@ -17,6 +17,7 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report


def save_to_pytorch_benchmark_format(
@@ -170,3 +171,6 @@ def run_to_completion(profile_dir: Optional[str] = None):
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)

# Generate the lite-profiler report if enabled.
maybe_emit_lite_profiler_report()
4 changes: 4 additions & 0 deletions vllm/benchmarks/serve.py
@@ -48,6 +48,7 @@
from vllm.benchmarks.lib.ready_checker import wait_for_endpoint
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report

MILLISECONDS_TO_SECONDS_CONVERSION = 1000

@@ -1443,4 +1444,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)

# Generate the lite-profiler report if enabled.
maybe_emit_lite_profiler_report()

return result_json
4 changes: 4 additions & 0 deletions vllm/benchmarks/throughput.py
@@ -35,6 +35,7 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import merge_async_iterators
from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report


def run_vllm(
@@ -789,3 +790,6 @@ def main(args: argparse.Namespace):
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)

# Generate the lite-profiler report if enabled.
maybe_emit_lite_profiler_report()
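
All three benchmark entry points (latency, serve, throughput) finish with the same reporting hook. The sketch below is illustrative and not part of the diff; the driver function and its body are assumptions.

from vllm.utils.lite_profiler import maybe_emit_lite_profiler_report

def run_my_benchmark() -> None:
    ...  # placeholder: issue requests against the engine here
    # Prints a summary only when VLLM_LITE_PROFILER_LOG_PATH is set and the
    # profiled process has produced the log file; otherwise it returns early
    # (with a warning if the path is set but the file is missing).
    maybe_emit_lite_profiler_report()

if __name__ == "__main__":
    run_my_benchmark()
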
7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -209,6 +209,7 @@
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_LITE_PROFILER_LOG_PATH: Optional[str] = None


def get_default_cache_root():
@@ -1391,6 +1392,12 @@ def get_vllm_port() -> Optional[int]:
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Log path for the lightweight timing profiler.
# If this path is set (not None), the lightweight profiler is enabled and
# records the elapsed time of each instrumented scope to this file.
"VLLM_LITE_PROFILER_LOG_PATH": lambda: os.getenv(
"VLLM_LITE_PROFILER_LOG_PATH", None
),
}

# --8<-- [end:env-vars-definition]
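
As a usage sketch (not part of the diff; the log path below is illustrative), the profiler is enabled purely through this environment variable. The accessor added above is a lazy os.getenv lambda, so the value is read when envs.VLLM_LITE_PROFILER_LOG_PATH is first accessed and is inherited by any spawned worker processes through the environment.

import os

# Illustrative location; any writable path works.
os.environ["VLLM_LITE_PROFILER_LOG_PATH"] = "/tmp/vllm_lite_profiler.log"

import vllm.envs as envs  # imported after setting the variable on purpose

# The lambda in envs.py reads os.environ lazily, so the value is picked up.
assert envs.VLLM_LITE_PROFILER_LOG_PATH == "/tmp/vllm_lite_profiler.log"
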
144 changes: 144 additions & 0 deletions vllm/utils/lite_profiler.py
@@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Minimal helpers for opt-in lightweight timing collection."""

from __future__ import annotations

import atexit
import multiprocessing
import os
import time
from contextlib import suppress
from types import TracebackType
from typing import TextIO

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)


def _should_log_results() -> bool:
"""Check if the current process should log results.
Only the data-parallel rank 0 engine core and worker 0 should emit logs in
multi-process deployments so that we avoid duplicating identical timing
data.
"""
process = multiprocessing.current_process()
return process.name in ("EngineCore_DP0", "VllmWorker-0")


# Cache for log file handle
_log_file: TextIO | None = None
log_results = _should_log_results()


def _write_log_entry(name: str, elapsed_us: int) -> None:
"""Write a profiler entry using cached file handle for optimal performance.
This function implements an efficient caching approach where the file handle
is opened once and reused for all subsequent writes. This eliminates the
significant overhead of opening/closing files for every profiler entry,
which is crucial for maintaining the lightweight nature of the profiler.
The cached file handle is automatically closed on program exit via atexit.
"""
global _log_file
_LOG_PATH = envs.VLLM_LITE_PROFILER_LOG_PATH

if not log_results or _LOG_PATH is None:
return

# Handle the case where the file handle was opened in the parent process but
# we are now in a child process; the file descriptor may become invalid
# after fork.
if _log_file is not None:
try:
# Verify if the file handle is still valid
_log_file.tell()
except (OSError, ValueError):
# File handle is stale, clear and reopen
_log_file = None

# Write the log entry
log_line = f"{name}|{elapsed_us}\n"
if _log_file is None:
directory = os.path.dirname(_LOG_PATH)
if directory:
os.makedirs(directory, exist_ok=True)
# ruff: noqa: SIM115 - intentionally keeping file handle cached globally
_log_file = open(_LOG_PATH, "a", buffering=50000)
atexit.register(_log_file.close)

_log_file.write(log_line)


class LiteScope:
"""Lightweight context manager for timing code blocks with minimal overhead.
This class provides a simple way to measure and log the execution time of
code blocks using Python's context manager protocol (with statement). It's
designed for high-frequency profiling with minimal performance impact.
"""

def __init__(self, name: str) -> None:
self._name = name
self._start_time: int | None = None

def __enter__(self) -> None:
self._start_time = time.perf_counter_ns()

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool:
if self._start_time is not None and exc_type is None:
elapsed_ns = time.perf_counter_ns() - self._start_time
# Use integer microseconds for better performance
elapsed_us = elapsed_ns // 1000
_write_log_entry(self._name, elapsed_us)
return False


def maybe_emit_lite_profiler_report() -> None:
"""Generate and display a summary report of profiling data if available.
This function serves as the main entry point for displaying profiling
results. It checks that profiling was enabled and that a log file exists,
then delegates to the lite_profiler_report module to print per-scope timing
totals and their share of the overall run time.
"""

log_path = envs.VLLM_LITE_PROFILER_LOG_PATH
if log_path is None:
return

if not os.path.exists(log_path):
logger.warning(
"Lite profiler log not found. Ensure the profiled process sets "
"the expected path."
)
return

try:
from vllm.utils import lite_profiler_report
except Exception as exc: # pragma: no cover - import error should not crash
logger.error("Failed to import lite profiler report helper: %s", exc)
return

logger.info("")
logger.info("Lite profiler summary (%s):", log_path)
try:
# Generate and display the summary report
lite_profiler_report.summarize_log(log_path)

# Clear the log file to avoid accumulating data from multiple runs
with suppress(OSError):
directory = os.path.dirname(log_path)
if directory:
os.makedirs(directory, exist_ok=True)
with open(log_path, "w"):
pass
except Exception as exc: # pragma: no cover - avoid crashing benchmarks
logger.error("Failed to summarize lite profiler log %s: %s", log_path, exc)
121 changes: 121 additions & 0 deletions vllm/utils/lite_profiler_report.py
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Summarize a single vLLM lite-profiler log in tabular form.
The module consumes the pipe-separated records emitted by
`vllm.utils.lite_profiler` and expects log lines in the format:
"<scope_name>|<elapsed_microseconds>".
"""

from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable, Sequence

from vllm.logger import init_logger

logger = init_logger(__name__)


def _extract_event_us(filename: str) -> dict[str, list[int]]:
"""Collect the microsecond timings for every scope in ``filenames``."""

all_event_us: dict[str, list[int]] = defaultdict(list)
try:
with open(filename, encoding="utf-8") as f:
for raw_line in f:
line = raw_line.strip()
if not line:
continue

# Parse the format: "scope_name|elapsed_microseconds"
if "|" in line:
try:
scope_name, elapsed_us_str = line.split("|", 1)
elapsed_us = int(elapsed_us_str)
all_event_us[scope_name].append(elapsed_us)
except (ValueError, IndexError):
# Skip malformed lines
continue
except FileNotFoundError:
raise FileNotFoundError(f"Lite-profiler log not found: {filename}") from None
return all_event_us


def _sum_events(event_us: dict[str, list[int]]) -> dict[str, int]:
return {event: sum(values) for event, values in event_us.items()}


def _format_duration_us(value_us: int, total_us: int) -> str:
# Convert microseconds to seconds
seconds = value_us / 1e6 if value_us else 0.0
percent = (value_us * 100.0 / total_us) if total_us else 0.0
return f"{seconds:.2f}s ({percent:.2f}%)"


def _render_table(
title: str, headers: Sequence[str], rows: Iterable[Sequence[str]]
) -> None:
table = [list(headers)] + [list(row) for row in rows]
widths = [max(len(row[i]) for row in table) for i in range(len(headers))]

logger.info("")
logger.info(title)
separator = "-" * (sum(widths) + 3 * (len(widths) - 1))
logger.info(separator)

def _fmt(row: Sequence[str]) -> str:
return " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row))

logger.info(_fmt(table[0]))
logger.info(" | ".join("-" * w for w in widths))
for row in table[1:]:
logger.info(_fmt(row))


TOP_EVENTS = [
"Input:Process",
"Step:Schedule",
"Step:Model",
"Step:Output",
]

MODEL_EVENTS = [
"Model:UpdateState",
"Model:PrepareInput",
"Model:Forward",
"Model:Postprocess",
"Model:Sample",
"Model:Bookkeep",
"Model:EPLB",
"Model:Draft",
]


def _compute_table_rows(
event_us_sum: dict[str, int],
events: Sequence[str],
) -> list[str]:
total_us = sum(event_us_sum.get(event, 0) for event in events)
cells = []
for event in events:
cells.append(_format_duration_us(event_us_sum.get(event, 0), total_us))
# Convert microseconds to seconds
total_seconds = total_us / 1_000_000 if total_us else 0.0
cells.append(f"{total_seconds:.2f}s")
return cells


def _print_breakdown_tables(event_us_sum: dict[str, int]) -> None:
for title, events in (
("Top-level pipeline events", TOP_EVENTS),
("Model events breakdown (only includes the main key events)", MODEL_EVENTS),
):
headers = [*events, "TOTAL"]
rows = [_compute_table_rows(event_us_sum, events)]
_render_table(title, headers, rows)


def summarize_log(log_path: str) -> None:
event_us = _extract_event_us(log_path)
event_us_sum = _sum_events(event_us)
_print_breakdown_tables(event_us_sum)
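
To make the record format concrete, here is an illustrative snippet (not part of the diff; the temporary file and timings are made up) that writes a tiny log by hand and summarizes it:

import tempfile

from vllm.utils import lite_profiler_report

with tempfile.NamedTemporaryFile("w", suffix=".log", delete=False) as tmp:
    # Each line is "<scope_name>|<elapsed_microseconds>".
    tmp.write("Step:Schedule|1200\n")
    tmp.write("Step:Schedule|800\n")
    tmp.write("Step:Model|54000\n")
    tmp.write("Step:Output|950\n")

# Logs the "Top-level pipeline events" and "Model events breakdown" tables;
# each cell shows total seconds and the share of that table's listed events.
lite_profiler_report.summarize_log(tmp.name)
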
28 changes: 18 additions & 10 deletions vllm/v1/engine/core.py
@@ -65,6 +65,7 @@
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -317,14 +318,19 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output,
)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("Step:Schedule"):
scheduler_output = self.scheduler.schedule()

with record_function_or_nullcontext("Step:Model"):
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output,
)

with record_function_or_nullcontext("Step:Output"):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
) # type: ignore

return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0)

@@ -822,15 +828,17 @@ def _process_input_queue(self):
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
with record_function_or_nullcontext("Input:Process"):
self._handle_client_request(*req)

if waited:
logger.debug("EngineCore loop active.")

# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
with record_function_or_nullcontext("Input:Process"):
self._handle_client_request(*req)

def _process_engine_step(self) -> bool:
"""Called only when there are unfinished local requests."""