Skip to content

On Policy Distillation hangs while training #1934

@pumetu

Description

@pumetu

I am attempting to run on-policy distillation (OPD) with a custom environment, but the trainer hangs partway through training: sometimes around step 10, sometimes around step 90. The logs show that the orchestrator has already crashed with an error, yet the trainer never exits and hangs there indefinitely. The orchestrator error reports a request with max_tokens set to 0. I am not sure where this value comes from, because as far as I can tell from the code it is hardcoded to be at least 1.

Config

max_steps = 100
seq_len = 2048
output_dir = "outputs/Qwen3-1.7B-Base-distill-Qwen3.5-27B-typhoon-s-distill"

[deployment]
type = "single_node"
num_teacher_gpus = 2

[wandb]
project = "med-tool-use"
offline = true

[model]
name = "/project/lt200394-thllmV/thaillm-dev-models/Qwen3-1.7B-Base"

[orchestrator]
batch_size = 64
rollouts_per_example = 4

[orchestrator.teacher_model.model]
name = "/project/lt200394-thllmV/thaillm-dev-models/Qwen3.5-27B"

[orchestrator.sampling]
max_tokens = 512

[orchestrator.buffer]
skip_verification = true

[[orchestrator.env]]
id = "typhoon_s_distill"
args = { dataset_name = "/home/ptuchind/dataset/typhoon-s-instruct-post-training" }

[trainer]

[trainer.loss]
teacher_tau = 1.0
adv_tau = 0.0  # Disable reward-based learning

[inference]

Command: uv run --offline @opd.toml

Logs

Trainer logs

[default0]:20:40:35    INFO Injecting Prime LM head with chunk size None
[default0]:20:40:36    INFO Building 1-D device mesh with ['dp_shard'], [1]
[default0]:20:40:37 WARNING Model uses tied word embeddings, so skipping the last-layer no-reshard optimization.
[default0]:20:40:37    INFO Initializing tokenizer (name='/project/lt200394-thllmV/thaillm-dev-models/Qwen3-1.7B-Base' trust_remote_code=None chat_template=None)
[default0]:20:40:37    INFO Setting up loss function (type='default' ratio_type='token' token_mask_high=8.0 token_mask_low=0.125 sequence_clip_high=10.0 geo_mask_high=10.0 geo_mask_low=0.1 sequence_mask_low=0.0 sequence_mask_high=100.0 adv_tau=0.0 teacher_tau=1.0 kl_tau=0.0)
[default0]:20:40:37    INFO Initializing optimizer (lr=1e-06 weight_decay=0.01 max_norm=1.0 type='adamw' betas1=0.9 betas2=0.999)
[default0]:20:40:37    INFO Using `constant` scheduler (type='constant')
[default0]:20:40:37    INFO Initializing weight broadcast (type='filesystem' save_sharded=True save_format='safetensors')
[default0]:20:40:37    INFO Starting from step 0 (total_tokens=0, total_samples=0)
[default0]:20:40:37    INFO Initializing data loader (fake=None)
[default0]:20:40:37    INFO Starting training loop (max_steps=100)
[default0]:20:40:51 SUCCESS Step 0 | Time: 13.33s | Loss: -0.0000 | Entropy: 1.9949 | Mismatch KL: 0.0014 | Grad. Norm: 0.2675 | LR: 1.00e-06 | Throughput: 0 tokens/s | MFU: 0.0% | Peak Mem.: 34.9 GiB
[default0]:20:41:03 SUCCESS Step 1 | Time: 6.93s | Loss: -0.0002 | Entropy: 2.0880 | Mismatch KL: 0.0019 | Grad. Norm: 0.3135 | LR: 1.00e-06 | Throughput: 2800 tokens/s | MFU: 10.5% | Peak Mem.: 38.5 GiB
[default0]:20:41:15 SUCCESS Step 2 | Time: 6.55s | Loss: -0.0000 | Entropy: 1.9693 | Mismatch KL: 0.0013 | Grad. Norm: 0.2313 | LR: 1.00e-06 | Throughput: 2850 tokens/s | MFU: 10.7% | Peak Mem.: 38.3 GiB
[default0]:20:41:32 SUCCESS Step 3 | Time: 9.35s | Loss: -0.0001 | Entropy: 2.2666 | Mismatch KL: 0.0015 | Grad. Norm: 0.2595 | LR: 1.00e-06 | Throughput: 2889 tokens/s | MFU: 10.9% | Peak Mem.: 38.4 GiB
[default0]:20:41:44 SUCCESS Step 4 | Time: 6.53s | Loss: -0.0002 | Entropy: 1.4445 | Mismatch KL: 0.0012 | Grad. Norm: 0.1993 | LR: 1.00e-06 | Throughput: 2811 tokens/s | MFU: 10.6% | Peak Mem.: 38.3 GiB
..............
[default0]:20:59:31 SUCCESS Step 84 | Time: 5.80s | Loss: -0.0001 | Entropy: 2.3357 | Mismatch KL: 0.0014 | Grad. Norm: 0.1519 | LR: 1.00e-06 | Throughput: 2734 tokens/s | MFU: 10.3% | Peak Mem.: 38.2 GiB
[default0]:20:59:44 SUCCESS Step 85 | Time: 6.20s | Loss: -0.0002 | Entropy: 1.9497 | Mismatch KL: 0.0014 | Grad. Norm: 0.1390 | LR: 1.00e-06 | Throughput: 2673 tokens/s | MFU: 10.1% | Peak Mem.: 38.5 GiB
[default0]:20:59:57 SUCCESS Step 86 | Time: 6.78s | Loss: -0.0003 | Entropy: 2.1782 | Mismatch KL: 0.0022 | Grad. Norm: 0.1315 | LR: 1.00e-06 | Throughput: 2664 tokens/s | MFU: 10.0% | Peak Mem.: 38.5 GiB
[default0]:21:00:12 SUCCESS Step 87 | Time: 8.19s | Loss: 0.0000 | Entropy: 1.7198 | Mismatch KL: 0.0013 | Grad. Norm: 0.1887 | LR: 1.00e-06 | Throughput: 2730 tokens/s | MFU: 10.3% | Peak Mem.: 38.2 GiB
[default0]:21:00:26 SUCCESS Step 88 | Time: 7.87s | Loss: -0.0002 | Entropy: 1.8186 | Mismatch KL: 0.0012 | Grad. Norm: 0.1583 | LR: 1.00e-06 | Throughput: 2810 tokens/s | MFU: 10.6% | Peak Mem.: 38.5 GiB
[default0]:21:00:40 SUCCESS Step 89 | Time: 8.44s | Loss: -0.0001 | Entropy: 1.9336 | Mismatch KL: 0.0013 | Grad. Norm: 0.1564 | LR: 1.00e-06 | Throughput: 2793 tokens/s | MFU: 10.5% | Peak Mem.: 38.4 GiB
[default0]:21:00:53 SUCCESS Step 90 | Time: 6.58s | Loss: -0.0003 | Entropy: 2.1272 | Mismatch KL: 0.0014 | Grad. Norm: 0.1838 | LR: 1.00e-06 | Throughput: 2778 tokens/s | MFU: 10.4% | Peak Mem.: 38.4 GiB
[default0]:21:01:07 SUCCESS Step 91 | Time: 7.30s | Loss: -0.0000 | Entropy: 2.0223 | Mismatch KL: 0.0015 | Grad. Norm: 0.6518 | LR: 1.00e-06 | Throughput: 2769 tokens/s | MFU: 10.4% | Peak Mem.: 38.5 GiB
[default0]:21:01:21 SUCCESS Step 92 | Time: 8.15s | Loss: -0.0003 | Entropy: 2.3708 | Mismatch KL: 0.0014 | Grad. Norm: 0.1632 | LR: 1.00e-06 | Throughput: 2770 tokens/s | MFU: 10.4% | Peak Mem.: 38.5 GiB
W   (output ends with this truncated line; the trainer then hangs here indefinitely)

Orchestrator logs

Generating rollouts (train):   0%|          | 0/64 [00:00<?, ?it/s]
Generating rollouts (train):   6%|▋         | 4/64 [00:05<01:18,  1.31s/it]
Generating rollouts (train):  12%|█▎        | 8/64 [00:05<00:34,  1.63it/s]
Generating rollouts (train):  25%|██▌       | 16/64 [00:05<00:11,  4.18it/s]
Generating rollouts (train):  44%|████▍     | 28/64 [00:11<00:13,  2.73it/s]
Generating rollouts (train): 100%|██████████| 64/64 [00:11<00:00,  8.83it/s]
Generating rollouts (train): 100%|██████████| 64/64 [00:11<00:00,  5.44it/s]
21:01:26    INFO Detected 15/64 rollouts (gibberish=15)
21:01:26    INFO Computing teacher logprobs for 64 training examples
21:01:26   ERROR Fatal error in orchestrate
Traceback (most recent call last):

  File "/lustrefs/disk/home/ptuchind/prime-rl/.venv/bin/orchestrator", line 10, in <module>
    sys.exit(main())
    │   │    └ <function main at 0x7f4c5477fec0>
    │   └ <bound method ExitHooks.exit of <wandb.sdk.lib.exit_hooks.ExitHooks object at 0x7f4c45a649b0>>
    └ <module 'sys' (built-in)>

  File "/lustrefs/disk/home/ptuchind/prime-rl/src/prime_rl/orchestrator/orchestrator.py", line 837, in main
    asyncio.run(orchestrate(parse_argv(OrchestratorConfig)))
    │       │   │           │          └ <class 'prime_rl.configs.orchestrator.OrchestratorConfig'>
    │       │   │           └ <function parse_argv at 0x7f4c98716ca0>
    │       │   └ <function orchestrate at 0x7f4c76bf0860>
    │       └ <function run at 0x7f5289f76a20>
    └ <module 'asyncio' from '/lustrefs/disk/home/ptuchind/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/a...

  File "/lustrefs/disk/home/ptuchind/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           │      │   └ <coroutine object orchestrate at 0x7f4c546ea110>
           │      └ <function Runner.run at 0x7f5289b8d760>
           └ <asyncio.runners.Runner object at 0x7f4c685986e0>
  File "/lustrefs/disk/home/ptuchind/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           │    │     │                  └ <Task pending name='Task-1' coro=<orchestrate() running at /lustrefs/disk/home/ptuchind/prime-rl/src/prime_rl/utils/utils.py:...
           │    │     └ <function BaseEventLoop.run_until_complete at 0x7f5289b83240>
           │    └ <_UnixSelectorEventLoop running=True closed=False debug=False>
           └ <asyncio.runners.Runner object at 0x7f4c685986e0>
  File "/lustrefs/disk/home/ptuchind/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/asyncio/base_events.py", line 678, in run_until_complete
    self.run_forever()
    │    └ <function BaseEventLoop.run_forever at 0x7f5289b831a0>
    └ <_UnixSelectorEventLoop running=True closed=False debug=False>
  File "/lustrefs/disk/home/ptuchind/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/asyncio/base_events.py", line 645, in run_forever
    self._run_once()
    │    └ <function BaseEventLoop._run_once at 0x7f5289b8cfe0>
    └ <_UnixSelectorEventLoop running=True closed=False debug=False>
  File "/lustrefs/disk/home/ptuchind/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/asyncio/base_events.py", line 1999, in _run_once
    handle._run()
    │      └ <function Handle._run at 0x7f5289cba8e0>
    └ <Handle Task.task_wakeup(<_GatheringFu...ode': 400}}")>)>
  File "/lustrefs/disk/home/ptuchind/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
    │    │            │    │           │    └ <member '_args' of 'Handle' objects>
    │    │            │    │           └ <Handle Task.task_wakeup(<_GatheringFu...ode': 400}}")>)>
    │    │            │    └ <member '_callback' of 'Handle' objects>
    │    │            └ <Handle Task.task_wakeup(<_GatheringFu...ode': 400}}")>)>
    │    └ <member '_context' of 'Handle' objects>
    └ <Handle Task.task_wakeup(<_GatheringFu...ode': 400}}")>)>

> File "/lustrefs/disk/home/ptuchind/prime-rl/src/prime_rl/utils/utils.py", line 129, in async_wrapper
    ret = await func(*args, **kwargs)
                │     │       └ {}
                │     └ (OrchestratorConfig(toml_files=None, client=ClientConfig(timeout=1200, base_url=['http://localhost:8000/v1'], api_key_var='VL...
                └ <function orchestrate at 0x7f4c54752e80>

  File "/lustrefs/disk/home/ptuchind/prime-rl/src/prime_rl/orchestrator/orchestrator.py", line 544, in orchestrate
    teacher_logprobs_list = await compute_teacher_logprobs(
                                  └ <function compute_teacher_logprobs at 0x7f4c5477d940>

  File "/lustrefs/disk/home/ptuchind/prime-rl/src/prime_rl/orchestrator/utils.py", line 176, in compute_teacher_logprobs
    return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)])
                 │       │        │                                                         │     │         └ [TrainingSample(prompt_ids=[151644, 872, 198, 40, 2776, 3330, 369, 1045, 1550, 62936, 81895, 369, 847, 220, 19, 42, 8718, 320...
                 │       │        │                                                         │     └ [ClientConfig(client_idx=0, client_type='openai_chat_completions', api_key_var='VLLM_API_KEY', api_base_url='http://localhost...
                 │       │        │                                                         └ <class 'itertools.cycle'>
                 │       │        └ <function compute_teacher_logprobs.<locals>._compute_single at 0x7f4bafca2700>
                 │       └ <function gather at 0x7f5289b7bba0>
                 └ <module 'asyncio' from '/lustrefs/disk/home/ptuchind/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/a...

  File "/lustrefs/disk/home/ptuchind/prime-rl/src/prime_rl/orchestrator/utils.py", line 157, in _compute_single
    response = await client.post(
                     │      └ <function AsyncAPIClient.post at 0x7f4c790f0040>
                     └ <openai.AsyncOpenAI object at 0x7f4bafb5efc0>

  File "/lustrefs/disk/home/ptuchind/prime-rl/.venv/lib/python3.12/site-packages/openai/_base_client.py", line 1794, in post
    return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)
                 │    │       │        │            │                  └ None
                 │    │       │        │            └ False
                 │    │       │        └ FinalRequestOptions(method='post', url='/chat/completions/tokens', params={}, headers=NOT_GIVEN, max_retries=NOT_GIVEN, timeo...
                 │    │       └ <class 'prime_rl.orchestrator.patches.monkey_patch_chat_completion_logprobs.<locals>.ModdedChatCompletion'>
                 │    └ <function AsyncAPIClient.request at 0x7f4c790eb7e0>
                 └ <openai.AsyncOpenAI object at 0x7f4bafb5efc0>
  File "/lustrefs/disk/home/ptuchind/prime-rl/.venv/lib/python3.12/site-packages/openai/_base_client.py", line 1594, in request
    raise self._make_status_error_from_response(err.response) from None
          │    └ <function BaseClient._make_status_error_from_response at 0x7f4c790e8ae0>
          └ <openai.AsyncOpenAI object at 0x7f4bafb5efc0>

openai.BadRequestError: Error code: 400 - {'error': {'message': 'max_tokens must be at least 1, got 0. (parameter=max_tokens, value=0)', 'type': 'BadRequestError', 'param': 'max_tokens', 'code': 400}}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions