Skip to content

Commit c1a565a

Browse files
njhillskyloevil
authored andcommitted
[CI] Fail subprocess tests with root-cause error (vllm-project#23795)
Signed-off-by: Nick Hill <[email protected]>
1 parent 29ead89 commit c1a565a

File tree

6 files changed

+138
-33
lines changed

6 files changed

+138
-33
lines changed

requirements/test.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli
2121
sentence-transformers # required for embedding tests
2222
soundfile # required for audio tests
2323
jiwer # required for audio tests
24+
tblib # for pickling test exceptions
2425
timm >=1.0.17 # required for internvl and gemma3n-mm test
2526
torch==2.8.0
2627
torchaudio==2.8.0

requirements/test.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ contourpy==1.3.0
137137
# via matplotlib
138138
cramjam==2.9.0
139139
# via fastparquet
140-
cupy-cuda12x==13.3.0
140+
cupy-cuda12x==13.6.0
141141
# via ray
142142
cycler==0.12.1
143143
# via matplotlib
@@ -1032,6 +1032,8 @@ tabledata==1.3.3
10321032
# via pytablewriter
10331033
tabulate==0.9.0
10341034
# via sacrebleu
1035+
tblib==3.1.0
1036+
# via -r requirements/test.in
10351037
tcolorpy==0.1.6
10361038
# via pytablewriter
10371039
tenacity==9.0.0

tests/async_engine/test_api_server.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import copyreg
45
import os
56
import subprocess
67
import sys
@@ -10,6 +11,30 @@
1011

1112
import pytest
1213
import requests
14+
import urllib3.exceptions
15+
16+
17+
def _pickle_new_connection_error(obj):
18+
"""Custom pickler for NewConnectionError to fix tblib compatibility."""
19+
# Extract the original message by removing the "conn: " prefix
20+
full_message = obj.args[0] if obj.args else ""
21+
if ': ' in full_message:
22+
# Split off the connection part and keep the actual message
23+
_, actual_message = full_message.split(': ', 1)
24+
else:
25+
actual_message = full_message
26+
return _unpickle_new_connection_error, (actual_message, )
27+
28+
29+
def _unpickle_new_connection_error(message):
30+
"""Custom unpickler for NewConnectionError."""
31+
# Create with None as conn and the actual message
32+
return urllib3.exceptions.NewConnectionError(None, message)
33+
34+
35+
# Register the custom pickle/unpickle functions for tblib compatibility
36+
copyreg.pickle(urllib3.exceptions.NewConnectionError,
37+
_pickle_new_connection_error)
1338

1439

1540
def _query_server(prompt: str, max_tokens: int = 5) -> dict:
@@ -52,6 +77,7 @@ def api_server(distributed_executor_backend: str):
5277
uvicorn_process.terminate()
5378

5479

80+
@pytest.mark.timeout(300)
5581
@pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"])
5682
def test_api_server(api_server, distributed_executor_backend: str):
5783
"""

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
# ruff: noqa
5+
6+
from tblib import pickling_support
7+
8+
# Install support for pickling exceptions so that we can nicely propagate
9+
# failures from tests running in a subprocess.
10+
# This should be run before any custom exception subclasses are defined.
11+
pickling_support.install()
12+
313
import http.server
414
import json
515
import math

tests/utils.py

Lines changed: 92 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import asyncio
5+
import contextlib
56
import copy
67
import functools
78
import importlib
@@ -13,7 +14,7 @@
1314
import tempfile
1415
import time
1516
import warnings
16-
from contextlib import contextmanager, suppress
17+
from contextlib import ExitStack, contextmanager, suppress
1718
from multiprocessing import Process
1819
from pathlib import Path
1920
from typing import Any, Callable, Literal, Optional, Union
@@ -800,43 +801,106 @@ def wait_for_gpu_memory_to_clear(*,
800801

801802

802803
def fork_new_process_for_each_test(
803-
f: Callable[_P, None]) -> Callable[_P, None]:
804+
func: Callable[_P, None]) -> Callable[_P, None]:
804805
"""Decorator to fork a new process for each test function.
805806
See https://github.com/vllm-project/vllm/issues/7053 for more details.
806807
"""
807808

808-
@functools.wraps(f)
809+
@functools.wraps(func)
809810
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
810811
# Make the process the leader of its own process group
811812
# to avoid sending SIGTERM to the parent process
812813
os.setpgrp()
813814
from _pytest.outcomes import Skipped
814-
pid = os.fork()
815-
print(f"Fork a new process to run a test {pid}")
816-
if pid == 0:
817-
try:
818-
f(*args, **kwargs)
819-
except Skipped as e:
820-
# convert Skipped to exit code 0
821-
print(str(e))
822-
os._exit(0)
823-
except Exception:
824-
import traceback
825-
traceback.print_exc()
826-
os._exit(1)
815+
816+
# Create a unique temporary file to store exception info from child
817+
# process. Use test function name and process ID to avoid collisions.
818+
with tempfile.NamedTemporaryFile(
819+
delete=False,
820+
mode='w+b',
821+
prefix=f"vllm_test_{func.__name__}_{os.getpid()}_",
822+
suffix=".exc") as exc_file, ExitStack() as delete_after:
823+
exc_file_path = exc_file.name
824+
delete_after.callback(os.remove, exc_file_path)
825+
826+
pid = os.fork()
827+
print(f"Fork a new process to run a test {pid}")
828+
if pid == 0:
829+
# Parent process responsible for deleting, don't delete
830+
# in child.
831+
delete_after.pop_all()
832+
try:
833+
func(*args, **kwargs)
834+
except Skipped as e:
835+
# convert Skipped to exit code 0
836+
print(str(e))
837+
os._exit(0)
838+
except Exception as e:
839+
import traceback
840+
tb_string = traceback.format_exc()
841+
842+
# Try to serialize the exception object first
843+
exc_to_serialize: dict[str, Any]
844+
try:
845+
# First, try to pickle the actual exception with
846+
# its traceback.
847+
exc_to_serialize = {'pickled_exception': e}
848+
# Test if it can be pickled
849+
cloudpickle.dumps(exc_to_serialize)
850+
except (Exception, KeyboardInterrupt):
851+
# Fall back to string-based approach.
852+
exc_to_serialize = {
853+
'exception_type': type(e).__name__,
854+
'exception_msg': str(e),
855+
'traceback': tb_string,
856+
}
857+
try:
858+
with open(exc_file_path, 'wb') as f:
859+
cloudpickle.dump(exc_to_serialize, f)
860+
except Exception:
861+
# Fallback: just print the traceback.
862+
print(tb_string)
863+
os._exit(1)
864+
else:
865+
os._exit(0)
827866
else:
828-
os._exit(0)
829-
else:
830-
pgid = os.getpgid(pid)
831-
_pid, _exitcode = os.waitpid(pid, 0)
832-
# ignore SIGTERM signal itself
833-
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
834-
# kill all child processes
835-
os.killpg(pgid, signal.SIGTERM)
836-
# restore the signal handler
837-
signal.signal(signal.SIGTERM, old_signal_handler)
838-
assert _exitcode == 0, (f"function {f} failed when called with"
839-
f" args {args} and kwargs {kwargs}")
867+
pgid = os.getpgid(pid)
868+
_pid, _exitcode = os.waitpid(pid, 0)
869+
# ignore SIGTERM signal itself
870+
old_signal_handler = signal.signal(signal.SIGTERM,
871+
signal.SIG_IGN)
872+
# kill all child processes
873+
os.killpg(pgid, signal.SIGTERM)
874+
# restore the signal handler
875+
signal.signal(signal.SIGTERM, old_signal_handler)
876+
if _exitcode != 0:
877+
# Try to read the exception from the child process
878+
exc_info = {}
879+
if os.path.exists(exc_file_path):
880+
with contextlib.suppress(Exception), \
881+
open(exc_file_path, 'rb') as f:
882+
exc_info = cloudpickle.load(f)
883+
884+
if (original_exception :=
885+
exc_info.get('pickled_exception')) is not None:
886+
# Re-raise the actual exception object if it was
887+
# successfully pickled.
888+
assert isinstance(original_exception, Exception)
889+
raise original_exception
890+
891+
if (original_tb := exc_info.get("traceback")) is not None:
892+
# Use string-based traceback for fallback case
893+
raise AssertionError(
894+
f"Test {func.__name__} failed when called with"
895+
f" args {args} and kwargs {kwargs}"
896+
f" (exit code: {_exitcode}):\n{original_tb}"
897+
) from None
898+
899+
# Fallback to the original generic error
900+
raise AssertionError(
901+
f"function {func.__name__} failed when called with"
902+
f" args {args} and kwargs {kwargs}"
903+
f" (exit code: {_exitcode})") from None
840904

841905
return wrapper
842906

vllm/executor/ray_distributed_executor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,12 @@ def _init_executor(self) -> None:
117117
self.driver_worker.execute_method)
118118

119119
def shutdown(self) -> None:
120-
logger.info(
121-
"Shutting down Ray distributed executor. If you see error log "
122-
"from logging.cc regarding SIGTERM received, please ignore because "
123-
"this is the expected termination process in Ray.")
120+
if logger:
121+
# Somehow logger can be None here.
122+
logger.info(
123+
"Shutting down Ray distributed executor. If you see error log "
124+
"from logging.cc regarding SIGTERM received, please ignore "
125+
"because this is the expected termination process in Ray.")
124126
if hasattr(self, "forward_dag") and self.forward_dag is not None:
125127
self.forward_dag.teardown()
126128
import ray

0 commit comments

Comments
 (0)