38 changes: 38 additions & 0 deletions docs/contributing/profiling.md
@@ -231,3 +231,41 @@ Leverage VLLM_GC_DEBUG environment variable to debug GC costs.
- VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elapsed times
- VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger to log top 5
collected objects for each gc.collect

## Lite Profiler

The lite profiler is a lightweight, minimal-overhead profiling tool designed to capture the time spent in the critical components of the system.

### How It Works

The lite profiler uses context managers to time code blocks and writes the timing data to a log file through a small (~50 KB) write buffer to keep overhead low. After execution completes, it generates an aggregate summary report showing the time spent in each pipeline stage.
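
As a rough illustration (a sketch, assuming `VLLM_LITE_PROFILER_LOG_PATH` has already been exported as described in the next section), the `LiteScope` context manager added by this PR can wrap any code block; on a successful exit it appends one `name|elapsed_ns` record to the log:

```python
import time

from vllm.utils.lite_profiler.lite_profiler import LiteScope

# Time an arbitrary block; the elapsed nanoseconds are appended to the
# file named by VLLM_LITE_PROFILER_LOG_PATH (writes are skipped outside
# the rank-0 engine-core/worker processes).
with LiteScope("Step:Model"):
    time.sleep(0.01)  # stand-in for the work being measured
```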

### Environment Variable

Enable lite profiling by setting the `VLLM_LITE_PROFILER_LOG_PATH` environment variable to the desired log file path:

```bash
export VLLM_LITE_PROFILER_LOG_PATH=/tmp/vllm_lite_profile.log
```

### Example Commands and Usage

#### Throughput Benchmark

Profile throughput with the lite profiler:

```bash
VLLM_LITE_PROFILER_LOG_PATH=/tmp/vllm_lite_profile.log \
vllm bench throughput \
--model meta-llama/Llama-3.1-8B-Instruct \
--dataset-name random \
--num-prompts 1000 \
--input-len 128 \
--output-len 128
```

The profiler will automatically generate a summary report at the end showing time breakdowns for:

- **Top-level pipeline events**: Schedule, Model execution, Output processing
- **Model events**: State updates, input preparation, forward pass, postprocessing, bookkeeping
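
The raw log can also be inspected directly. Each record is a pipe-separated `name|elapsed_ns` pair, so a few lines of Python (a sketch, independent of the shipped report module) are enough to aggregate totals per scope. Note that the benchmark entry points delete the log after printing the built-in summary, so copy it first if you want to post-process it:

```python
from collections import defaultdict

# Sum the recorded nanoseconds per scope from a lite-profiler log whose
# lines have the form "<scope_name>|<elapsed_ns>".
totals_ns: dict[str, int] = defaultdict(int)
with open("/tmp/vllm_lite_profile.log", encoding="utf-8") as f:
    for line in f:
        name, _, elapsed = line.strip().partition("|")
        if elapsed.isdigit():
            totals_ns[name] += int(elapsed)

for name, ns in sorted(totals_ns.items(), key=lambda kv: -kv[1]):
    print(f"{name}: {ns / 1e9:.3f}s")
```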
4 changes: 4 additions & 0 deletions vllm/benchmarks/latency.py
@@ -17,6 +17,7 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
from vllm.utils.lite_profiler.lite_profiler import maybe_emit_lite_profiler_report


def save_to_pytorch_benchmark_format(
@@ -170,3 +171,6 @@ def run_to_completion(profile_dir: Optional[str] = None):
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)

# Generate the lite-profiler report if enabled.
maybe_emit_lite_profiler_report()
4 changes: 4 additions & 0 deletions vllm/benchmarks/throughput.py
@@ -35,6 +35,7 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import merge_async_iterators
from vllm.utils.lite_profiler.lite_profiler import maybe_emit_lite_profiler_report


def run_vllm(
@@ -789,3 +790,6 @@ def main(args: argparse.Namespace):
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)

# Generate the lite-profiler report if enabled.
maybe_emit_lite_profiler_report()
7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -209,6 +209,7 @@
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_LITE_PROFILER_LOG_PATH: Optional[str] = None


def get_default_cache_root():
@@ -1391,6 +1392,12 @@ def get_vllm_port() -> Optional[int]:
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Log path for the lightweight timing profiler.
# If this path is set (not None), lightweight profiling will be enabled,
# providing detailed execution latency breakdown by segment.
"VLLM_LITE_PROFILER_LOG_PATH": lambda: os.getenv(
"VLLM_LITE_PROFILER_LOG_PATH", None
),
}

# --8<-- [end:env-vars-definition]
137 changes: 137 additions & 0 deletions vllm/utils/lite_profiler/lite_profiler.py
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Minimal helpers for opt-in lightweight timing collection."""

from __future__ import annotations

import atexit
import multiprocessing
import os
import time
from types import TracebackType
from typing import TextIO

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)


def _should_log_results() -> bool:
"""Check if the current process should log results.
Only the data-parallel rank 0 engine core and worker 0 should emit logs in
multi-process deployments so that we avoid duplicating identical timing
data.
"""
process = multiprocessing.current_process()
return process.name in ("EngineCore_DP0", "VllmWorker-0")


# Cache for log file handle
_log_file: TextIO | None = None


def _write_log_entry(name: str, elapsed_ns: int) -> None:
"""Write a profiler entry using cached file handle for optimal performance.

This function implements an efficient caching approach where the file handle
is opened once and reused for all subsequent writes. This eliminates the
significant overhead of opening/closing files for every profiler entry,
which is crucial for maintaining the lightweight nature of the profiler.

The cached file handle is automatically closed on program exit via atexit.
"""
global _log_file
_LOG_PATH = envs.VLLM_LITE_PROFILER_LOG_PATH
assert _LOG_PATH is not None

if not _should_log_results():
return

# Handle case where file handle was opened in parent but we're in the
# child process. The file descriptor may become invalid after fork
if _log_file is not None:
try:
# Verify if the file handle is still valid
_log_file.flush()
_log_file.tell()
except (OSError, ValueError):
# File handle is stale, clear and reopen
_log_file = None

# Write the log entry
log_line = f"{name}|{elapsed_ns}\n"
if _log_file is None:
directory = os.path.dirname(_LOG_PATH)
if directory:
os.makedirs(directory, exist_ok=True)
# ruff: noqa: SIM115 - intentionally keeping file handle cached globally
_log_file = open(_LOG_PATH, "a", buffering=50000)
atexit.register(_log_file.close)
Contributor:

This only closes the file when the Python interpreter shuts down, which makes me wonder whether we could read incomplete data if we don't shut down the server (because buffered data has not been flushed yet).

This is a bit tricky. One way is to explicitly close the file in maybe_emit_lite_profiler_report, but that only works for benchmark_latency and benchmark_throughput; benchmark serve runs the server and client separately.

If we can't find a good way, we can skip benchmark serve for now.

Contributor Author:

We can try to flush the file data before generating the summary, but I think that needs a bit of testing for the serving part.
I'll skip the serving part in this PR and take it up in a new PR with a few more changes.

_log_file.write(log_line)


class LiteScope:
"""Lightweight context manager for timing code blocks with minimal overhead.

This class provides a simple way to measure and log the execution time of
code blocks using Python's context manager protocol (with statement). It's
designed for high-frequency profiling with minimal performance impact.
"""

def __init__(self, name: str) -> None:
self._name = name
self._start_time: int | None = None

def __enter__(self) -> None:
self._start_time = time.perf_counter_ns()

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
if self._start_time is not None and exc_type is None:
elapsed_ns = time.perf_counter_ns() - self._start_time
_write_log_entry(self._name, elapsed_ns)


def maybe_emit_lite_profiler_report() -> None:
"""Generate and display a summary report of profiling data if available.

This function serves as the main entry point for analyzing and displaying
profiling results. It checks if profiling was enabled and a log file exists,
then delegates to the lite_profiler_report module to aggregate the recorded
timings and print the per-stage breakdown tables.
"""

log_path = envs.VLLM_LITE_PROFILER_LOG_PATH
if log_path is None:
return

# Ensure the log file is flushed and closed before generating report
global _log_file
if _log_file is not None:
_log_file.flush()
_log_file.close()
_log_file = None

if not os.path.exists(log_path):
logger.warning(
"Lite profiler log not found. Ensure the profiled process sets "
"the expected path."
)
return

from vllm.utils.lite_profiler import lite_profiler_report

logger.info("")
logger.info("Lite profiler summary (%s):", log_path)
try:
# Generate and display the summary report
lite_profiler_report.summarize_log(log_path)
os.remove(log_path)
except Exception as exc: # pragma: no cover - avoid crashing benchmarks
logger.error("Failed to summarize lite profiler log %s: %s", log_path, exc)
115 changes: 115 additions & 0 deletions vllm/utils/lite_profiler/lite_profiler_report.py
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Summarize a single vLLM lite-profiler log in tabular form.
The script consumes the pipe-separated records emitted by `vllm.lite_profiler`
It expects log lines in the format: "<scope_name>|<elapsed_microseconds>"
"""

from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable, Sequence

from vllm.logger import init_logger

logger = init_logger(__name__)


def _extract_event_ns(filename: str) -> dict[str, list[int]]:
"""Collect the nanosecond timings for every scope in ``filename``."""

all_event_ns: dict[str, list[int]] = defaultdict(list)
try:
with open(filename, encoding="utf-8") as f:
for raw_line in f:
line = raw_line.strip()
if not line:
continue

# Parse the format: "scope_name|elapsed_nanoseconds"
if "|" in line:
try:
scope_name, elapsed_ns_str = line.split("|", 1)
elapsed_ns = int(elapsed_ns_str)
all_event_ns[scope_name].append(elapsed_ns)
except (ValueError, IndexError):
# Skip malformed lines
continue
except FileNotFoundError:
raise FileNotFoundError(f"Lite-profiler log not found: {filename}") from None
return all_event_ns


def _sum_events(event_ns: dict[str, list[int]]) -> dict[str, int]:
return {event: sum(values) for event, values in event_ns.items()}


def _format_duration_ns(value_ns: int, total_ns: int) -> str:
seconds = value_ns / 1e9 if value_ns else 0.0
percent = (value_ns * 100.0 / total_ns) if total_ns else 0.0
return f"{seconds:.3f}s ({percent:.2f}%)"


def _render_table(
title: str, headers: Sequence[str], rows: Iterable[Sequence[str]]
) -> None:
table = [list(headers)] + [list(row) for row in rows]
widths = [max(len(row[i]) for row in table) for i in range(len(headers))]

logger.info("")
logger.info(title)
separator = "-" * (sum(widths) + 3 * (len(widths) - 1))
logger.info(separator)

def _fmt(row: Sequence[str]) -> str:
return " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row))

logger.info(_fmt(table[0]))
logger.info(" | ".join("-" * w for w in widths))
for row in table[1:]:
logger.info(_fmt(row))


TOP_EVENTS = [
"Step:Schedule",
"Step:Model",
"Step:Output",
]

MODEL_EVENTS = [
"Model:UpdateState",
"Model:PrepareInput",
"Model:Forward",
"Model:Postprocess",
"Model:Bookkeep",
]
Comment on lines +74 to +86

Contributor:

Non-blocking: should we extract and accumulate the events automatically during line processing?

Contributor Author:

Oh, sorry, I didn't fully get this, could you please elaborate?

Contributor:

When analyzing the log file, we should be able to collect all logged events instead of specifying them here. Then, when people add new events to the code in the future, they won't need to update lite_profiler_report as well.

Contributor Author:

Ah, I see. Thanks for explaining.
Right now we are collecting all the metrics referenced by record_function() in the vLLM code into the log file, and for the different report components (like the model and step breakdowns) we just pick out the specific events listed here.
But I think we can move our approach closer to the one you mentioned later on.
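
A rough sketch of that suggestion (hypothetical, not part of this PR): since the scope names follow a `Prefix:Event` convention, the report tables could be derived from whatever names appear in the log by grouping on the prefix, for example:

```python
from collections import defaultdict


def group_events_by_prefix(
    event_ns_sum: dict[str, int],
) -> dict[str, dict[str, int]]:
    """Group summed timings by the prefix in "Prefix:Event" scope names."""
    groups: dict[str, dict[str, int]] = defaultdict(dict)
    for name, total_ns in event_ns_sum.items():
        prefix, _, event = name.partition(":")
        groups[prefix][event or name] = total_ns
    return groups


# e.g. {"Step": {"Schedule": ..., "Model": ...}, "Model": {"Forward": ...}}
```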



def _compute_table_rows(
event_ns_sum: dict[str, int],
events: Sequence[str],
) -> list[str]:
total_ns = sum(event_ns_sum.get(event, 0) for event in events)
cells = []
for event in events:
cells.append(_format_duration_ns(event_ns_sum.get(event, 0), total_ns))
total_seconds = total_ns / 1e9 if total_ns else 0.0
cells.append(f"{total_seconds:.3f}s")
return cells


def _print_breakdown_tables(event_ns_sum: dict[str, int]) -> None:
for title, events in (
("Top-level pipeline events", TOP_EVENTS),
("Model events breakdown (only includes the main key events)", MODEL_EVENTS),
):
headers = [*events, "TOTAL"]
rows = [_compute_table_rows(event_ns_sum, events)]
_render_table(title, headers, rows)


def summarize_log(log_path: str) -> None:
elapsed_ns = _extract_event_ns(log_path)
event_ns_sum = _sum_events(elapsed_ns)
_print_breakdown_tables(event_ns_sum)
22 changes: 14 additions & 8 deletions vllm/v1/engine/core.py
@@ -65,6 +65,7 @@
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -317,14 +318,19 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output,
)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("Step:Schedule"):
scheduler_output = self.scheduler.schedule()

with record_function_or_nullcontext("Step:Model"):
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output,
)

with record_function_or_nullcontext("Step:Output"):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
) # type: ignore

return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0)
