GRPO with Qwen3-4B #3331

@amahankali10

Description

Bug report

I am attempting to run GRPO with Qwen3-4B, following the instructions on the single-host tutorial page.
The run crashes with a dtype mismatch raised from the attention path in tpu_inference/models/jax/qwen3.py.

I ran the following commands to install MaxText from source:

uv pip install -e .[tpu-post-train] --resolution=lowest
install_maxtext_tpu_post_train_extra_deps

Then I ran the following script from inside the main maxtext directory (adapted from tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh).

set -ex
idx=$(date +%Y-%m-%d-%H-%M)
MODEL_NAME='qwen3-4b'
export MODEL_VARIATION='4b'
HF_GOLDEN_MODEL='Qwen/Qwen3-4B'
TOKENIZER_PATH='Qwen/Qwen3-4B'

export MODEL_BUCKET=${BASE_OUTPUT_DIRECTORY}/qwen3_end_to_end

python -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
    model_name=${MODEL_NAME} \
    hf_access_token=${HF_TOKEN} \
    base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \
    scan_layers=false

export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items

export RL_RUN_NAME=runner_grpo_${idx}
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
  model_name=${MODEL_NAME} \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${UNSCANNED_CKPT_PATH} \
  run_name=${RL_RUN_NAME} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  hf_access_token=${HF_TOKEN} \
  chips_per_vm=${CHIPS_PER_VM} \
  scan_layers=false

Please set the following environment variables before running the script:

export BASE_OUTPUT_DIRECTORY=...
export HF_TOKEN=...
export CHIPS_PER_VM=...

Logs/Output

Here is the traceback:

[rank0]: Traceback (most recent call last):
[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/train_rl.py", line 659, in <module>
[rank0]:     app.run(main)
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/absl/app.py", line 316, in run
[rank0]:     _run_main(main, args)
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
[rank0]:     sys.exit(main(argv))
[rank0]:              ^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/train_rl.py", line 655, in main
[rank0]:     rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
[rank0]:   File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/train_rl.py", line 602, in rl_train
[rank0]:     (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
[rank0]:                                                                     ^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/evaluate_rl.py", line 194, in evaluate
[rank0]:     multiple_call_responses = generate_responses(
[rank0]:                               ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/evaluate_rl.py", line 69, in generate_responses
[rank0]:     responses = rl_cluster.rollout.generate(
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tunix/rl/rollout/vllm_rollout.py", line 87, in generate
[rank0]:     self.output = self._sampler(
[rank0]:                   ^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tunix/generate/vllm_sampler.py", line 443, in __call__
[rank0]:     outputs = self.llm.generate(
[rank0]:               ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 456, in generate
[rank0]:     return self._run_completion(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1787, in _run_completion
[rank0]:     return self._render_and_run_requests(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1892, in _render_and_run_requests
[rank0]:     return self._run_engine(output_type, use_tqdm=use_tqdm)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1966, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/engine/llm_engine.py", line 301, in step
[rank0]:     outputs = self.engine_core.get_output()
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 283, in get_output
[rank0]:     outputs, model_executed = self.engine_core.step_fn()
[rank0]:                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 393, in step
[rank0]:     model_output = future.result()
[rank0]:                    ^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/.local/share/uv/python/cpython-3.12.13-linux-x86_64-gnu/lib/python3.12/concurrent/futures/_base.py", line 449, in result
[rank0]:     return self.__get_result()
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/.local/share/uv/python/cpython-3.12.13-linux-x86_64-gnu/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
[rank0]:     raise self._exception
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/executor/uniproc_executor.py", line 79, in collective_rpc
[rank0]:     result = run_method(self.driver_worker, method, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/serial_utils.py", line 459, in run_method
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/worker/worker_base.py", line 361, in execute_model
[rank0]:     return self.worker.execute_model(scheduler_output)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/worker/tpu_worker.py", line 350, in execute_model
[rank0]:     output = self.model_runner.execute_model(scheduler_output,
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/utils.py", line 359, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/runner/tpu_runner.py", line 606, in execute_model
[rank0]:     output = self._execute_model(scheduler_output,
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/runner/tpu_runner.py", line 798, in _execute_model
[rank0]:     aux_hidden_states) = self.model_fn(
[rank0]:                          ^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/common/model_loader.py", line 290, in run_model
[rank0]:     return model(*args)
[rank0]:            ^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/jax/qwen3.py", line 352, in __call__
[rank0]:     kv_caches, x = self.model(
[rank0]:                    ^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/jax/qwen2.py", line 363, in __call__
[rank0]:     kv_cache, x = layer(
[rank0]:                   ^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/jax/qwen2.py", line 277, in __call__
[rank0]:     kv_cache, attn_output = self.self_attn(
[rank0]:                             ^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/jax/qwen3.py", line 172, in __call__
[rank0]:     new_kv_cache, outputs = attention(
[rank0]:                             ^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/layers/common/attention_interface.py", line 403, in attention
[rank0]:     output, kv_cache = sharded_ragged_paged_attention(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/layers/common/attention_interface.py", line 358, in sharded_ragged_paged_attention
[rank0]:     return jax.shard_map(
[rank0]:            ^^^^^^^^^^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/layers/common/attention_interface.py", line 349, in _ragged_paged_attention
[rank0]:     return func(
[rank0]:            ^^^^^
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 1419, in ragged_paged_attention
[rank0]:     static_validate_inputs(
[rank0]:   File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 1261, in static_validate_inputs
[rank0]:     raise ValueError(
[rank0]: ValueError: Expected kv_cache.dtype=dtype(bfloat16) to be equal to k.dtype=dtype(bfloat16) and v.dtype=dtype('float32').
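For what it's worth, the failing check reduces to a few lines. Below is a minimal sketch (hypothetical shapes, plain jax.numpy arrays standing in for the paged KV cache — not the real kernel inputs) of the dtype validation that raises, plus the cast that makes it pass. This suggests the rollout path is handing the kernel a float32 v while k and the cache are bfloat16; casting v to the cache dtype before the kernel call may be one workaround, though I have not verified where in tpu_inference the cast belongs.

```python
import jax.numpy as jnp

# Stand-ins for the kernel inputs: cache and k are bfloat16, v arrives as float32.
kv_cache = jnp.zeros((4, 8), dtype=jnp.bfloat16)
k = jnp.zeros((4, 8), dtype=jnp.bfloat16)
v = jnp.zeros((4, 8), dtype=jnp.float32)

def validate(kv_cache, k, v):
    # Mirrors the failing check: all three dtypes must match.
    if not (kv_cache.dtype == k.dtype == v.dtype):
        raise ValueError(
            f"Expected kv_cache.dtype={kv_cache.dtype} to be equal to "
            f"k.dtype={k.dtype} and v.dtype={v.dtype}.")

try:
    validate(kv_cache, k, v)
except ValueError as e:
    print("mismatch:", e)

# Possible workaround (untested against tpu_inference): cast v to the cache dtype.
v = v.astype(kv_cache.dtype)
validate(kv_cache, k, v)  # passes after the cast
```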

Environment Information

I have reproduced this bug at several MaxText commits; the one I am currently using is 5dfd6a4.

JAX version: 0.8.3

Python version: 3.12.13

Hardware: TPU v4-8

OS: tpu-ubuntu2204-base


Labels

bug (Something isn't working)