
Question: how to run tunix with vLLM? #256

@OhadRubin

Hey,
Would you mind pinning the requirements for vLLM usage, or generally providing a bit more information about it? I'm not able to get tunix working in that setup.
Minimal steps to reproduce on a fresh TPU v4:

cd ~
git clone https://github.com/vllm-project/vllm
git clone https://github.com/google/tunix
cd tunix
# convert a JSONL file into a single JSON array, rewriting it in place
jsonl_to_json(){ python -c 'import sys,json;p=sys.argv[1];d=[json.loads(l) for l in open(p)];open(p,"w").write(json.dumps(d))' "$1"; }
mkdir -p root_dir/rl/grpo/data
wget https://raw.githubusercontent.com/openai/grade-school-math/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/data/train.jsonl -O root_dir/rl/grpo/data/gsm8k_train.json
wget https://raw.githubusercontent.com/openai/grade-school-math/2909d34ef28520753df82a2234c357259d254aa8/grade_school_math/data/test.jsonl -O root_dir/rl/grpo/data/gsm8k_test.json
jsonl_to_json root_dir/rl/grpo/data/gsm8k_train.json
jsonl_to_json root_dir/rl/grpo/data/gsm8k_test.json
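
For reference, the one-liner above is equivalent to this small standalone script (same in-place behavior, applied to the paths created above):

# jsonl_to_json.py -- readable equivalent of the shell one-liner:
# read a JSONL file and rewrite it in place as a single JSON array.
import json
import sys

def jsonl_to_json(path: str) -> None:
    with open(path) as f:
        records = [json.loads(line) for line in f]
    with open(path, "w") as f:
        json.dump(records, f)

if __name__ == "__main__":
    jsonl_to_json(sys.argv[1])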

# build the vLLM Docker container as https://docs.vllm.ai/en/v0.5.5/getting_started/tpu-installation.html instructs
(cd ~/vllm && sudo docker build -f docker/Dockerfile.tpu -t vllm-tpu .)
# run it interactively with the tunix repo mounted
(cd ~/vllm && sudo docker run --privileged --net host --shm-size=16G -v /home/$USER/tunix:/workspace/tunix -it vllm-tpu)
# inside the container's shell
cd /workspace/tunix
python3 -m pip install git+https://github.com/google/qwix grain tensorboardx jaxtyping huggingface-hub==0.34.4
python3 -m pip install --no-deps -e .  # --no-deps, otherwise this install overrides the vllm tpulib 0.18
python3 scripts/grpo_demo_llama3_qwen2.py --root-dir=root_dir --model-version=Qwen/Qwen2.5-0.5B-Instruct
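
Since the ask here is about pinning requirements, this is a snippet I'd run inside the container to record the versions actually in play (a sketch; the package list is my guess at the relevant ones, and the TPU runtime may be distributed under a different name):

# Print installed versions of the packages most likely involved.
# The package names below are assumptions about what matters here.
import importlib.metadata as md

for pkg in ("vllm", "torch", "torch_xla", "jax", "jaxlib", "libtpu", "libtpu-nightly"):
    try:
        print(pkg, md.version(pkg))
    except md.PackageNotFoundError:
        print(pkg, "not installed")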

Error trace:

root@v4-8-node-1:/workspace/tunix# python3 scripts/grpo_demo_llama3_qwen2.py --root-dir=root_dir --model-version=Qwen/Qwen2.5-0.5B-Instruct
/workspace/tunix/scripts/grpo_demo_llama3_qwen2.py:275: SyntaxWarning: invalid escape sequence '\$'
  
This script is still WIP and you'll need to download all the data tolocal first. Functionality and performance is not guaranteed. Try at your own discretion
Directory does not exist: root_dir/rl/grpo/demo/experiments/llama3/training_runs/2
INFO:absl: - Pathways not available. Using defaultHBM stats collector
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_0(process=0,(0,0,0,0))
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_1(process=0,(1,0,0,0))
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_2(process=0,(0,1,0,0))
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_3(process=0,(1,1,0,0))
train_dataset size: 1869, val_dataset size:0,test_dataset size: 50
{'answer': array(['3', '34', '300', '35'], dtype='<U3'),
 'prompts': array(['<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nMaria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much money, in dollars, does Maria have now?<end_of_turn>\n<start_of_turn>model<|im_end|>\n<|im_start|>assistant\n',
       '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nA wildlife team is monitoring the number of birds in a park. There are 3 blackbirds in each of the park’s 7 trees. There are also 13 magpies roaming around the park. How many birds are in the park in total?<end_of_turn>\n<start_of_turn>model<|im_end|>\n<|im_start|>assistant\n',
       '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nLast year, the school library purchased 50 new books. This year, it purchased 3 times as many books. If the library had 100 books before it purchased new books last year, how many books are in the library now?<end_of_turn>\n<start_of_turn>model<|im_end|>\n<|im_start|>assistant\n',
       '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nJame gets 20 singing lessons.  He gets the first lesson free and after the first 10 paid lessons he only needs to pay for every other lesson.  Each lesson is $5.  His uncle pays for half.  How much does James pay?<end_of_turn>\n<start_of_turn>model<|im_end|>\n<|im_start|>assistant\n'],
      dtype='<U636'),
 'question': array(['Maria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much money, in dollars, does Maria have now?',
       'A wildlife team is monitoring the number of birds in a park. There are 3 blackbirds in each of the park’s 7 trees. There are also 13 magpies roaming around the park. How many birds are in the park in total?',
       'Last year, the school library purchased 50 new books. This year, it purchased 3 times as many books. If the library had 100 books before it purchased new books last year, how many books are in the library now?',
       'Jame gets 20 singing lessons.  He gets the first lesson free and after the first 10 paid lessons he only needs to pay for every other lesson.  Each lesson is $5.  His uncle pays for half.  How much does James pay?'],
      dtype='<U213')}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.05s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 290/290 [00:00<00:00, 403.70it/s]
INFO:absl:After creating the reference lora model - Pathways not available. Using defaultHBM stats collector
INFO:absl:Using 475.3 MiB / 30.7 GiB (0.0150971112916562) on TPU_0(process=0,(0,0,0,0))
INFO:absl:Using 474.6 MiB / 30.7 GiB (0.015072592546090434) on TPU_1(process=0,(1,0,0,0))
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_2(process=0,(0,1,0,0))
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_3(process=0,(1,1,0,0))
INFO:absl:After creating a raw sampler - Pathways not available. Using defaultHBM stats collector
INFO:absl:Using 475.3 MiB / 30.7 GiB (0.0150971112916562) on TPU_0(process=0,(0,0,0,0))
INFO:absl:Using 474.6 MiB / 30.7 GiB (0.015072592546090434) on TPU_1(process=0,(1,0,0,0))
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_2(process=0,(0,1,0,0))
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_3(process=0,(1,1,0,0))
INFO:absl:After creating a new rollout worker - Pathways not available. Using defaultHBM stats collector
INFO:absl:Using 475.3 MiB / 30.7 GiB (0.0150971112916562) on TPU_0(process=0,(0,0,0,0))
INFO:absl:Using 474.6 MiB / 30.7 GiB (0.015072592546090434) on TPU_1(process=0,(1,0,0,0))
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_2(process=0,(0,1,0,0))
INFO:absl:Using 13.0 KiB / 30.7 GiB (4.0321782714101377e-07) on TPU_3(process=0,(1,1,0,0))
INFO 09-05 08:26:20 [__init__.py:241] Automatically detected platform tpu.
INFO 09-05 08:26:21 [tpu.py:235] tpu_commons not found, using vLLM's TpuPlatform
INFO 09-05 08:26:22 [utils.py:328] non-default args: {'max_model_len': 1536, 'tensor_parallel_size': 2, 'gpu_memory_utilization': 0.2, 'disable_log_stats': True, 'model': 'root_dir/rl/grpo/models/Qwen/Qwen2.5-0.5B-Instruct'}
WARNING 09-05 08:26:22 [__init__.py:557] The global random seed is set to 0. Since VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may affect the random state of the Python process that launched vLLM.
INFO 09-05 08:26:30 [__init__.py:748] Resolved architecture: Qwen2ForCausalLM
`torch_dtype` is deprecated! Use `dtype` instead!
INFO 09-05 08:26:30 [__init__.py:1786] Using max model len 1536
INFO 09-05 08:26:30 [scheduler.py:222] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 09-05 08:26:30 [tpu.py:114] [TPU] Forcing DYNAMO_ONCE compilation level, and disabling cudagraph.
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
INFO 09-05 08:26:33 [core.py:76] Initializing a V1 LLM engine (v0.10.2rc2.dev98+ge599e2c65) with config: model='root_dir/rl/grpo/models/Qwen/Qwen2.5-0.5B-Instruct', speculative_config=None, tokenizer='root_dir/rl/grpo/models/Qwen/Qwen2.5-0.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1536, download_dir=None, load_format=auto, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=None, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=root_dir/rl/grpo/models/Qwen/Qwen2.5-0.5B-Instruct, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level":2,"debug_dump_path":"","cache_dir":"","backend":"openxla","custom_ops":[],"splitting_ops":null,"use_inductor":true,"compile_sizes":null,"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":0,"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":null,"cudagraph_copy_inputs":false,"full_cuda_graph":false,"pass_config":{},"max_capture_size":null,"local_cache_dir":null}
WARNING 09-05 08:26:33 [multiproc_worker_utils.py:273] Reducing Torch parallelism from 120 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 09-05 08:26:33 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0, 1], buffer_handle=(2, 16777216, 10, 'psm_70330afd'), local_subscribe_addr='ipc:///tmp/5d75bd27-a3f4-4c9c-a96f-2c9e5f87df2c', remote_subscribe_addr=None, remote_addr_ipv6=False)
/usr/local/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
INFO 09-05 08:26:33 [importing.py:43] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 09-05 08:26:33 [importing.py:63] Triton not installed or not compatible; certain GPU-related functions will not be available.
INFO 09-05 08:26:33 [importing.py:43] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 09-05 08:26:33 [importing.py:63] Triton not installed or not compatible; certain GPU-related functions will not be available.
INFO 09-05 08:26:33 [tpu_worker.py:35] tpu_commons not found, using vLLM's TPUWorker.
WARNING 09-05 08:26:33 [tpu.py:170] Pin memory is not supported on TPU.
INFO 09-05 08:26:33 [tpu_worker.py:35] tpu_commons not found, using vLLM's TPUWorker.
WARNING 09-05 08:26:33 [tpu.py:170] Pin memory is not supported on TPU.
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_fa746b34'), local_subscribe_addr='ipc:///tmp/3ad45fdd-7fd0-44dc-b556-feef89561d6b', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_55c15548'), local_subscribe_addr='ipc:///tmp/835e0f66-194b-4c37-bf74-5f6e5f54c8b7', remote_subscribe_addr=None, remote_addr_ipv6=False)
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_communicator.py:23] tpu_commons not found, using vLLM's TpuCommunicator
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_communicator.py:66] TpuCommunicator initialized with MP
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_communicator.py:23] tpu_commons not found, using vLLM's TpuCommunicator
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_communicator.py:66] TpuCommunicator initialized with MP
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_1d683d9d'), local_subscribe_addr='ipc:///tmp/d7d24686-58d5-4eb0-bd60-7a3547c2503e', remote_subscribe_addr=None, remote_addr_ipv6=False)
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_communicator.py:66] TpuCommunicator initialized with MP
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_communicator.py:66] TpuCommunicator initialized with MP
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [parallel_state.py:1134] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [parallel_state.py:1134] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 1, EP rank 1
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1861] Using exponential token paddings:
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1861] Using exponential token paddings:
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     16
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     16
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     32
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     32
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     64
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     64
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     128
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     128
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     256
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     256
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     512
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     1024
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     512
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     2048
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     1024
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     4096
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     2048
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     8192
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     4096
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1863]     8192
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1827] Preparing request paddings:
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1827] Preparing request paddings:
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     8
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     8
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     16
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     16
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     32
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     32
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     64
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     64
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     128
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     128
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     256
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1834]     256
(VllmWorker TP0 pid=7644) INFO 09-05 08:26:33 [tpu_model_runner.py:1172] Loading model from scratch...
(VllmWorker TP1 pid=7646) INFO 09-05 08:26:33 [tpu_model_runner.py:1172] Loading model from scratch...
https://symbolize.stripped_domain/r/?trace=7f0a46ef863d,7f0b1c7ccddf,7f0a3e0968e7,7f0a3e0950ee,7f0a3e094b55,7f0a3e090c36,7f0a3aabab6d,7f0a3cd1b19d,7f0a3aa4ae32,7ee5c30bfae9,7ee5c30c012c,7ee5bc925f37,7ee5bc6260d0,7ee5bc630d47,7f0a55c364e3&map= 
*** SIGSEGV (@0x7f061515ec48), see go/stacktraces#s15 received by PID 7644 (TID 7644) on cpu 60; stack trace: ***
https://symbolize.stripped_domain/r/?trace=7f0a46ef863d,7f0b1c7ccddf,7f0a3e0968e7,7f0a3e0950ee,7f0a3e094b55,7f0a3e090c36,7f0a3aabab6d,7f0a3cd1b19d,7f0a3aa4ae32,7ee5c30bfae9,7ee5c30c012c,7ee5bc925f37,7ee5bc6260d0,7ee5bc630d47,7f0a55c364e3&map= 
*** SIGSEGV (@0x7f061515ec48), see go/stacktraces#s15 received by PID 7646 (TID 7646) on cpu 61; stack trace: ***
PC: @     0x7f0a46ef863d  (unknown)  (unknown)
PC: @     0x7f0a46ef863d  (unknown)  (unknown)
    @     0x7f0a46e6bfe5       1904  (unknown)
    @     0x7f0a46e6bfe5       1904  (unknown)
    @     0x7f0b1c7ccde0  2146701248  (unknown)
    @     0x7f0b1c7ccde0  2146701248  (unknown)
    @     0x7f0a3e0968e8         64  (unknown)
    @     0x7f0a3e0968e8         64  (unknown)
    @     0x7f0a3e0950ef        240  (unknown)
    @     0x7f0a3e0950ef        240  (unknown)
    @     0x7f0a3e094b56         64  (unknown)
    @     0x7f0a3e094b56         64  (unknown)
    @     0x7f0a3e090c37         64  (unknown)
    @     0x7f0a3e090c37         64  (unknown)
    @     0x7f0a3aabab6e       2256  (unknown)
    @     0x7f0a3aabab6e       2256  (unknown)
    @     0x7f0a3cd1b19e       1360  (unknown)
    @     0x7f0a3cd1b19e       1360  (unknown)
    @     0x7f0a3aa4ae33        944  (unknown)
    @     0x7f0a3aa4ae33        944  (unknown)
    @     0x7ee5c30bfaea       1104  xla::PjRtCApiClient::BufferFromHostBufferInternalImpl()
    @     0x7ee5c30bfaea       1104  xla::PjRtCApiClient::BufferFromHostBufferInternalImpl()
    @     0x7ee5c30c012d        224  xla::PjRtCApiClient::BufferFromHostBuffer()
    @     0x7ee5c30c012d        224  xla::PjRtCApiClient::BufferFromHostBuffer()
    @     0x7ee5bc925f38       1152  torch_xla::runtime::PjRtComputationClient::TransferToDevice()
    @     0x7ee5bc925f38       1152  torch_xla::runtime::PjRtComputationClient::TransferToDevice()
    @     0x7ee5bc6260d1        992  torch_xla::TensorToXlaData()
    @     0x7ee5bc6260d1        992  torch_xla::TensorToXlaData()
    @     0x7ee5bc630d48         32  torch_xla::XlaBackendImpl::MakeComputationDataFromTensor()
    @     0x7ee5bc630d48         32  torch_xla::XlaBackendImpl::MakeComputationDataFromTensor()
    @     0x7f0a55c364e4  (unknown)  torch::lazy::TensorToDataHandle()
    @     0x7f0a55c364e4  (unknown)  torch::lazy::TensorToDataHandle()
    @ ... and at least 3 more frames
    @ ... and at least 3 more frames
https://symbolize.stripped_domain/r/?trace=https://symbolize.stripped_domain/r/?trace=7f0a46ef863d,7f0a46ef863d,7f0a46e6bfe4,7f0a46e6bfe4,7f0b1c7ccddf,7f0b1c7ccddf,7f0a3e0968e7,7f0a3e0968e7,7f0a3e0950ee,7f0a3e0950ee,7f0a3e094b55,7f0a3e094b55,7f0a3e090c36,7f0a3e090c36,7f0a3aabab6d,7f0a3aabab6d,7f0a3cd1b19d,7f0a3cd1b19d,7f0a3aa4ae32,7f0a3aa4ae32,7ee5c30bfae9,7ee5c30bfae9,7ee5c30c012c,7ee5c30c012c,7ee5bc925f37,7ee5bc925f37,7ee5bc6260d0,7ee5bc6260d0,7ee5bc630d47,7ee5bc630d47,7f0a55c364e37f0a55c364e3&map=&map= 
 
E0905 08:26:33.933235    7644 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0905 08:26:33.933235    7646 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0905 08:26:33.933260    7646 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0905 08:26:33.933260    7644 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0905 08:26:33.933280    7646 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0905 08:26:33.933281    7644 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0905 08:26:33.933304    7644 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0905 08:26:33.933304    7646 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0905 08:26:33.933312    7644 coredump_hook.cc:457] RAW: Dumping core locally.
E0905 08:26:33.933314    7646 coredump_hook.cc:457] RAW: Dumping core locally.
https://symbolize.stripped_domain/r/?trace=7f0b1c781f17,7f0b1c7ccddf,7f0a46d4bb0d,7f0a46d4bc82,7f0a46d45794,7f0a46d4676a,7f0a46d45a04,7f0a46d46498,7f0a46e6c37c,7f0b1c7ccddf,7f0a3e0968e7,7f0a3e0950ee,7f0a3e094b55,7f0a3e090c36,7f0a3aabab6d,7f0a3cd1b19d,7f0a3aa4ae32,7ee5c30bfae9,7ee5c30c012c,7ee5bc925f37,7ee5bc6260d0,7ee5bc630d47,7f0a55c364e3&map= 
E0905 08:27:35.109189    7646 process_state.cc:1175] RAW: Signal 11 raised at PC: 0x7f0b1c781f17 while already in FailureSignalHandler!
E0905 08:27:35.109196    7646 process_state.cc:1210] RAW: Raising 11 signal with default behavior
https://symbolize.stripped_domain/r/?trace=7f0b1c781f17,7f0b1c7ccddf,7f0a46d4bb0d,7f0a46d4bc82,7f0a46d45794,7f0a46d4676a,7f0a46d45a04,7f0a46d46498,7f0a46e6c37c,7f0b1c7ccddf,7f0a3e0968e7,7f0a3e0950ee,7f0a3e094b55,7f0a3e090c36,7f0a3aabab6d,7f0a3cd1b19d,7f0a3aa4ae32,7ee5c30bfae9,7ee5c30c012c,7ee5bc925f37,7ee5bc6260d0,7ee5bc630d47,7f0a55c364e3&map= 
E0905 08:27:35.112495    7644 process_state.cc:1175] RAW: Signal 11 raised at PC: 0x7f0b1c781f17 while already in FailureSignalHandler!
E0905 08:27:35.112502    7644 process_state.cc:1210] RAW: Raising 11 signal with default behavior
Traceback (most recent call last):
  File "/workspace/tunix/scripts/grpo_demo_llama3_qwen2.py", line 825, in <module>
    rl_cluster = rl_cluster_lib.RLCluster(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/tunix/tunix/rl/rl_cluster.py", line 166, in __init__
    self._init_cluster()
  File "/workspace/tunix/tunix/rl/rl_cluster.py", line 289, in _init_cluster
    self._rollout = vllm_rollout.VllmRollout(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/tunix/tunix/rl/rollout/vllm_rollout.py", line 41, in __init__
    self._sampler = vllm_sampler.VllmSampler(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/tunix/tunix/generate/vllm_sampler.py", line 94, in __init__
    self.llm = LLM(**self.args)
               ^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/entrypoints/llm.py", line 272, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/engine/llm_engine.py", line 492, in from_engine_args
    return engine_cls.from_vllm_config(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/engine/llm_engine.py", line 127, in from_vllm_config
    return cls(vllm_config=vllm_config,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/engine/llm_engine.py", line 104, in __init__
    self.engine_core = EngineCoreClient.make_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 82, in make_client
    return InprocClient(vllm_config, executor_class, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 245, in __init__
    self.engine_core = EngineCore(*args, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/engine/core.py", line 82, in __init__
    self.model_executor = executor_class(vllm_config)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/executor/executor_base.py", line 54, in __init__
    self._init_executor()
  File "/workspace/vllm/vllm/v1/executor/multiproc_executor.py", line 96, in _init_executor
    self.workers = WorkerProc.wait_for_ready(unready_workers)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/executor/multiproc_executor.py", line 487, in wait_for_ready
    raise e from None
Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
/usr/local/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 3 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
root@v4-8-node-1:/workspace/tunix# 
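
For what it's worth, the SIGSEGV fires inside LLM(**self.args) while torch_xla transfers weights to the device (TransferToDevice), after JAX has already initialized the TPU (see the HBM stats above) and with the os.fork() warning in play. A minimal vLLM-only script, run in the same container without importing tunix or JAX first, might help isolate whether vLLM-on-TPU works standalone (a sketch; the model path and engine args are copied from the failing run):

# Hypothetical isolation step: drive vLLM on TPU by itself, without tunix,
# to see if the torch_xla TransferToDevice crash reproduces outside RLCluster.
from vllm import LLM, SamplingParams

llm = LLM(
    model="root_dir/rl/grpo/models/Qwen/Qwen2.5-0.5B-Instruct",  # same local path the demo resolved
    max_model_len=1536,
    tensor_parallel_size=2,
    gpu_memory_utilization=0.2,
)
outputs = llm.generate(["2 + 2 ="], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)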
