[Core] Lite weight profiler implementation #26648
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Minimal helpers for opt-in lightweight timing collection."""

from __future__ import annotations

import atexit
import multiprocessing
import os
import time
from types import TracebackType
from typing import TextIO

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)


def _should_log_results() -> bool:
    """Check whether the current process should log results.

    Only the data-parallel rank 0 engine core and worker 0 should emit logs in
    multi-process deployments, so that we avoid duplicating identical timing
    data.
    """
    process = multiprocessing.current_process()
    return process.name in ("EngineCore_DP0", "VllmWorker-0")


# Cache for the log file handle
_log_file: TextIO | None = None


def _write_log_entry(name: str, elapsed_ns: int) -> None:
    """Write a profiler entry using a cached file handle for minimal overhead.

    The file handle is opened once and reused for all subsequent writes. This
    eliminates the significant overhead of opening and closing the file for
    every profiler entry, which is crucial for keeping the profiler lightweight.

    The cached file handle is automatically closed on program exit via atexit.
    """
    global _log_file
    _LOG_PATH = envs.VLLM_LITE_PROFILER_LOG_PATH
    assert _LOG_PATH is not None

    if not _should_log_results():
        return

    # Handle the case where the file handle was opened in the parent but we
    # are now in a child process: the file descriptor may become invalid
    # after fork.
    if _log_file is not None:
        try:
            # Verify that the file handle is still valid
            _log_file.flush()
            _log_file.tell()
        except (OSError, ValueError):
            # File handle is stale; clear it and reopen below
            _log_file = None

    # Write the log entry
    log_line = f"{name}|{elapsed_ns}\n"
    if _log_file is None:
        directory = os.path.dirname(_LOG_PATH)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # ruff: noqa: SIM115 - intentionally keeping file handle cached globally
        _log_file = open(_LOG_PATH, "a", buffering=50000)
        atexit.register(_log_file.close)

    _log_file.write(log_line)

Review comment (on the atexit.register line): This only closes the file when the Python interpreter shuts down, which makes me wonder whether it is possible to read incomplete data if we don't shut down the server (buffered data not flushed yet). This is a bit tricky. One way is to explicitly close the handle before the summary is generated; if we can't find a good way, we can skip it.

Reply: We can try to flush the file data before generating the summary, but I think that needs a bit of testing for the serving part.
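Relating to the buffering discussion above, here is a minimal sketch of an explicit flush hook that a serving shutdown path could call; it is not part of this diff, and the helper name is hypothetical:

```python
def _flush_lite_profiler_log() -> None:
    """Force any buffered profiler entries to disk without closing the handle."""
    if _log_file is not None:
        _log_file.flush()
```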


class LiteScope:
    """Lightweight context manager for timing code blocks with minimal overhead.

    This class provides a simple way to measure and log the execution time of
    code blocks using Python's context manager protocol (the ``with``
    statement). It is designed for high-frequency profiling with minimal
    performance impact.
    """

    def __init__(self, name: str) -> None:
        self._name = name
        self._start_time: int | None = None

    def __enter__(self) -> None:
        self._start_time = time.perf_counter_ns()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        if self._start_time is not None and exc_type is None:
            elapsed_ns = time.perf_counter_ns() - self._start_time
            _write_log_entry(self._name, elapsed_ns)
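For illustration, a minimal usage sketch of LiteScope; the module path is assumed from this PR's imports and the timed workload is a stand-in:

```python
import time

# Module path assumed from the "vllm.utils.lite_profiler" package used in this diff.
from vllm.utils.lite_profiler.lite_profiler import LiteScope

# Requires VLLM_LITE_PROFILER_LOG_PATH to be set. When the block exits cleanly,
# "<name>|<elapsed_ns>" is appended to that file, but only from the
# EngineCore_DP0 / VllmWorker-0 processes (see _should_log_results above).
with LiteScope("Step:Model"):
    time.sleep(0.01)  # stand-in for the model-execution step being timed
```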


def maybe_emit_lite_profiler_report() -> None:
    """Generate and display a summary report of profiling data if available.

    This function serves as the main entry point for analyzing and displaying
    profiling results. It checks whether profiling was enabled and a log file
    exists, then delegates to the lite_profiler_report module to generate
    statistics such as function call counts, timing distributions, and
    performance insights.
    """
    log_path = envs.VLLM_LITE_PROFILER_LOG_PATH
    if log_path is None:
        return

    # Ensure the log file is flushed and closed before generating the report
    global _log_file
    if _log_file is not None:
        _log_file.flush()
        _log_file.close()
        _log_file = None

    if not os.path.exists(log_path):
        logger.warning(
            "Lite profiler log not found. Ensure the profiled process sets "
            "the expected path."
        )
        return

    from vllm.utils.lite_profiler import lite_profiler_report

    logger.info("")
    logger.info("Lite profiler summary (%s):", log_path)
    try:
        # Generate and display the summary report
        lite_profiler_report.summarize_log(log_path)
        os.remove(log_path)
    except Exception as exc:  # pragma: no cover - avoid crashing benchmarks
        logger.error("Failed to summarize lite profiler log %s: %s", log_path, exc)
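Taken together, a rough end-to-end flow might look like the sketch below. This is illustrative only: the log path value and module path are assumptions, and the environment variable must be set before vLLM reads its environment.

```python
import os

# From this diff: the profiler is opt-in via this environment variable.
os.environ.setdefault("VLLM_LITE_PROFILER_LOG_PATH", "/tmp/vllm_lite_profiler.log")

# ... start the engine and run the benchmark; LiteScope call sites inside the
# EngineCore_DP0 / VllmWorker-0 processes append records to the log file ...

# Module path assumed from this PR's imports.
from vllm.utils.lite_profiler.lite_profiler import maybe_emit_lite_profiler_report

# Flushes and closes the cached handle, prints the summary tables via the
# logger, and removes the log file afterwards.
maybe_emit_lite_profiler_report()
```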
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Summarize a single vLLM lite-profiler log in tabular form.

The script consumes the pipe-separated records emitted by `vllm.lite_profiler`.
It expects log lines in the format: "<scope_name>|<elapsed_nanoseconds>".
"""

from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable, Sequence

from vllm.logger import init_logger

logger = init_logger(__name__)


def _extract_event_ns(filename: str) -> dict[str, list[int]]:
    """Collect the nanosecond timings for every scope in ``filename``."""
    all_event_ns: dict[str, list[int]] = defaultdict(list)
    try:
        with open(filename, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue

                # Parse the format: "scope_name|elapsed_nanoseconds"
                if "|" in line:
                    try:
                        scope_name, elapsed_ns_str = line.split("|", 1)
                        elapsed_ns = int(elapsed_ns_str)
                        all_event_ns[scope_name].append(elapsed_ns)
                    except (ValueError, IndexError):
                        # Skip malformed lines
                        continue
    except FileNotFoundError:
        raise FileNotFoundError(f"Lite-profiler log not found: {filename}") from None
    return all_event_ns
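As a quick illustration of the record format the parser expects (the sample values are made up):

```python
import tempfile

# Two Step:Model records and one malformed line that the parser skips.
with tempfile.NamedTemporaryFile("w", suffix=".log", delete=False) as f:
    f.write("Step:Model|1200000\nStep:Model|800000\nnot a record\n")
    path = f.name

print(dict(_extract_event_ns(path)))
# {'Step:Model': [1200000, 800000]}
```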


def _sum_events(event_ns: dict[str, list[int]]) -> dict[str, int]:
    return {event: sum(values) for event, values in event_ns.items()}


def _format_duration_ns(value_ns: int, total_ns: int) -> str:
    seconds = value_ns / 1e9 if value_ns else 0.0
    percent = (value_ns * 100.0 / total_ns) if total_ns else 0.0
    return f"{seconds:.3f}s ({percent:.2f}%)"


def _render_table(
    title: str, headers: Sequence[str], rows: Iterable[Sequence[str]]
) -> None:
    table = [list(headers)] + [list(row) for row in rows]
    widths = [max(len(row[i]) for row in table) for i in range(len(headers))]

    logger.info("")
    logger.info(title)
    separator = "-" * (sum(widths) + 3 * (len(widths) - 1))
    logger.info(separator)

    def _fmt(row: Sequence[str]) -> str:
        return " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row))

    logger.info(_fmt(table[0]))
    logger.info(" | ".join("-" * w for w in widths))
    for row in table[1:]:
        logger.info(_fmt(row))


TOP_EVENTS = [
    "Step:Schedule",
    "Step:Model",
    "Step:Output",
]

MODEL_EVENTS = [
    "Model:UpdateState",
    "Model:PrepareInput",
    "Model:Forward",
    "Model:Postprocess",
    "Model:Bookkeep",
]

Review comment (on lines +74 to +86, non-blocking): Should we extract and accumulate the events automatically during line processing?

Reply: Oh, sorry, I didn't fully get this; could you please elaborate on it?

Review comment: When analyzing the log file, we should be able to collect all logged events instead of specifying them here. That way, when people add new events to the code in the future, they don't need to update lite_profiler_report as well.

Reply: Ah, I see. Thanks for explaining.
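Following the suggestion in that thread, here is a hedged sketch of deriving the event groups from the parsed log instead of hard-coding the lists. It is not part of this diff, and grouping by the "<Group>:" prefix is an assumption about how such a change might look:

```python
def _discover_events(event_ns_sum: dict[str, int]) -> dict[str, list[str]]:
    """Group every logged scope name by its "<Group>:" prefix.

    This would let new LiteScope names show up in the report automatically,
    without updating TOP_EVENTS / MODEL_EVENTS by hand.
    """
    groups: dict[str, list[str]] = {}
    for event in sorted(event_ns_sum):
        group = event.split(":", 1)[0] if ":" in event else "Other"
        groups.setdefault(group, []).append(event)
    return groups
```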


def _compute_table_rows(
    event_ns_sum: dict[str, int],
    events: Sequence[str],
) -> list[str]:
    total_ns = sum(event_ns_sum.get(event, 0) for event in events)
    cells = []
    for event in events:
        cells.append(_format_duration_ns(event_ns_sum.get(event, 0), total_ns))
    total_seconds = total_ns / 1e9 if total_ns else 0.0
    cells.append(f"{total_seconds:.3f}s")
    return cells


def _print_breakdown_tables(event_ns_sum: dict[str, int]) -> None:
    for title, events in (
        ("Top-level pipeline events", TOP_EVENTS),
        ("Model events breakdown (only includes the main key events)", MODEL_EVENTS),
    ):
        headers = [*events, "TOTAL"]
        rows = [_compute_table_rows(event_ns_sum, events)]
        _render_table(title, headers, rows)


def summarize_log(log_path: str) -> None:
    event_ns = _extract_event_ns(log_path)
    event_ns_sum = _sum_events(event_ns)
    _print_breakdown_tables(event_ns_sum)
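The report module can also be pointed at an existing log directly; the import path comes from this diff, while the log path here is just an example:

```python
from vllm.utils.lite_profiler import lite_profiler_report

# Summarize a previously collected lite-profiler log.
lite_profiler_report.summarize_log("/tmp/vllm_lite_profiler.log")
```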