|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 |
|
4 | 4 | import asyncio |
| 5 | +import contextlib |
5 | 6 | import copy |
6 | 7 | import functools |
7 | 8 | import importlib |
|
13 | 14 | import tempfile |
14 | 15 | import time |
15 | 16 | import warnings |
16 | | -from contextlib import contextmanager, suppress |
| 17 | +from contextlib import ExitStack, contextmanager, suppress |
17 | 18 | from multiprocessing import Process |
18 | 19 | from pathlib import Path |
19 | 20 | from typing import Any, Callable, Literal, Optional, Union |
@@ -800,43 +801,106 @@ def wait_for_gpu_memory_to_clear(*, |
800 | 801 |
|
801 | 802 |
|
def _decode_wait_status(status: int) -> int:
    """Translate a raw ``os.waitpid`` status into a conventional exit code.

    Returns the child's exit code if it exited normally, the negated signal
    number if it was terminated by a signal, and the raw status otherwise.
    """
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    if os.WIFSIGNALED(status):
        return -os.WTERMSIG(status)
    return status


def fork_new_process_for_each_test(
        func: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to fork a new process for each test function.
    See https://github.com/vllm-project/vllm/issues/7053 for more details.

    The decorated test body runs in a forked child. If it raises, the child
    serializes the exception (or, when that is not possible, its type,
    message and formatted traceback) into a temporary file and exits with
    status 1. The parent waits for the child, sends SIGTERM to the process
    group to reap stray subprocesses, and then re-raises the child's
    exception or raises ``AssertionError`` describing the failure.

    Args:
        func: The test function to run in a forked process.

    Returns:
        A wrapper with the same signature as ``func``.

    Raises:
        BaseException: Whatever the child raised, if it could be pickled.
        AssertionError: If the child failed and its exception could not be
            recovered as an object.
    """

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped

        # Create a unique temporary file to store exception info from child
        # process. Use test function name and process ID to avoid collisions.
        with tempfile.NamedTemporaryFile(
                delete=False,
                mode='w+b',
                prefix=f"vllm_test_{func.__name__}_{os.getpid()}_",
                suffix=".exc") as exc_file, ExitStack() as delete_after:
            exc_file_path = exc_file.name
            delete_after.callback(os.remove, exc_file_path)

            pid = os.fork()
            print(f"Fork a new process to run a test {pid}")
            if pid == 0:
                # Parent process responsible for deleting, don't delete
                # in child.
                delete_after.pop_all()
                # Assume failure until proven otherwise; the `finally` below
                # guarantees the child never returns into the caller's
                # (i.e. the test runner's) code, whatever happens here.
                child_exit_code = 1
                try:
                    func(*args, **kwargs)
                    child_exit_code = 0
                except Skipped as e:
                    # convert Skipped to exit code 0
                    child_exit_code = 0
                    # os._exit() does not flush stdio buffers, so flush here.
                    print(str(e), flush=True)
                except BaseException as e:
                    # BaseException (not just Exception) so that outcomes
                    # such as pytest.fail(), SystemExit or KeyboardInterrupt
                    # are reported to the parent instead of escaping.
                    import traceback
                    tb_string = traceback.format_exc()
                    try:
                        try:
                            # First, try to pickle the actual exception
                            # object. Note that plain pickling does not
                            # carry the traceback along.
                            payload = cloudpickle.dumps(
                                {'pickled_exception': e})
                        except (Exception, KeyboardInterrupt):
                            # Fall back to string-based approach.
                            payload = cloudpickle.dumps({
                                'exception_type': type(e).__name__,
                                'exception_msg': str(e),
                                'traceback': tb_string,
                            })
                        with open(exc_file_path, 'wb') as f:
                            f.write(payload)
                    except Exception:
                        # Fallback: just print the traceback.
                        print(tb_string, flush=True)
                finally:
                    os._exit(child_exit_code)
            else:
                pgid = os.getpgid(pid)
                _, wait_status = os.waitpid(pid, 0)
                # ignore SIGTERM signal itself
                old_signal_handler = signal.signal(signal.SIGTERM,
                                                   signal.SIG_IGN)
                try:
                    # kill all child processes
                    os.killpg(pgid, signal.SIGTERM)
                finally:
                    # restore the signal handler
                    signal.signal(signal.SIGTERM, old_signal_handler)
                if wait_status != 0:
                    # waitpid() yields an encoded status, not an exit code.
                    exit_code = _decode_wait_status(wait_status)

                    # Try to read the exception from the child process.
                    # The file is written only by our own forked child, so
                    # unpickling it is not loading untrusted data.
                    exc_info = {}
                    if os.path.exists(exc_file_path):
                        with contextlib.suppress(Exception), \
                                open(exc_file_path, 'rb') as f:
                            exc_info = cloudpickle.load(f)

                    original_exception = exc_info.get('pickled_exception')
                    if isinstance(original_exception, BaseException):
                        # Re-raise the actual exception object if it was
                        # successfully pickled.
                        raise original_exception

                    if (original_tb := exc_info.get("traceback")) is not None:
                        # Use string-based traceback for fallback case
                        raise AssertionError(
                            f"Test {func.__name__} failed when called with"
                            f" args {args} and kwargs {kwargs}"
                            f" (exit code: {exit_code}):\n{original_tb}"
                        ) from None

                    # Fallback to the original generic error
                    raise AssertionError(
                        f"function {func.__name__} failed when called with"
                        f" args {args} and kwargs {kwargs}"
                        f" (exit code: {exit_code})") from None

    return wrapper
842 | 906 |
|
|
0 commit comments