Labels: bug (Something isn't working)
Description
Bug report
I am attempting to run GRPO with Qwen3-4B, following the instructions on the single-host tutorial page.
The run crashes due to a dtype mismatch raised from the attention kernel called via tpu_inference/models/jax/qwen3.py.
I ran the following command to install MaxText from source:
uv pip install -e .[tpu-post-train] --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
Then I ran the following script from inside the main maxtext directory (adapted from tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh):
set -ex
idx=$(date +%Y-%m-%d-%H-%M)
MODEL_NAME='qwen3-4b'
export MODEL_VARIATION='4b'
HF_GOLDEN_MODEL='Qwen/Qwen3-4B'
TOKENIZER_PATH='Qwen/Qwen3-4B'
export MODEL_BUCKET=${BASE_OUTPUT_DIRECTORY}/qwen3_end_to_end
python -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
model_name=${MODEL_NAME} \
hf_access_token=${HF_TOKEN} \
base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \
scan_layers=false
export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items
export RL_RUN_NAME=runner_grpo_${idx}
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL_NAME} \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
run_name=${RL_RUN_NAME} \
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
hf_access_token=${HF_TOKEN} \
chips_per_vm=${CHIPS_PER_VM} \
scan_layers=false
Please set the following environment variables before running the script:
export BASE_OUTPUT_DIRECTORY=...
export HF_TOKEN=...
export CHIPS_PER_VM=...
Logs/Output
Here is the traceback:
[rank0]: Traceback (most recent call last):
[rank0]: File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]: File "<frozen runpy>", line 88, in _run_code
[rank0]: File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/train_rl.py", line 659, in <module>
[rank0]: app.run(main)
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/absl/app.py", line 316, in run
[rank0]: _run_main(main, args)
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
[rank0]: sys.exit(main(argv))
[rank0]: ^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/train_rl.py", line 655, in main
[rank0]: rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
[rank0]: File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/train_rl.py", line 602, in rl_train
[rank0]: (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
[rank0]: ^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/evaluate_rl.py", line 194, in evaluate
[rank0]: multiple_call_responses = generate_responses(
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext/src/maxtext/trainers/post_train/rl/evaluate_rl.py", line 69, in generate_responses
[rank0]: responses = rl_cluster.rollout.generate(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tunix/rl/rollout/vllm_rollout.py", line 87, in generate
[rank0]: self.output = self._sampler(
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tunix/generate/vllm_sampler.py", line 443, in __call__
[rank0]: outputs = self.llm.generate(
[rank0]: ^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 456, in generate
[rank0]: return self._run_completion(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1787, in _run_completion
[rank0]: return self._render_and_run_requests(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1892, in _render_and_run_requests
[rank0]: return self._run_engine(output_type, use_tqdm=use_tqdm)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1966, in _run_engine
[rank0]: step_outputs = self.llm_engine.step()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/engine/llm_engine.py", line 301, in step
[rank0]: outputs = self.engine_core.get_output()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 283, in get_output
[rank0]: outputs, model_executed = self.engine_core.step_fn()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 393, in step
[rank0]: model_output = future.result()
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/.local/share/uv/python/cpython-3.12.13-linux-x86_64-gnu/lib/python3.12/concurrent/futures/_base.py", line 449, in result
[rank0]: return self.__get_result()
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/.local/share/uv/python/cpython-3.12.13-linux-x86_64-gnu/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
[rank0]: raise self._exception
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/executor/uniproc_executor.py", line 79, in collective_rpc
[rank0]: result = run_method(self.driver_worker, method, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/serial_utils.py", line 459, in run_method
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/vllm/v1/worker/worker_base.py", line 361, in execute_model
[rank0]: return self.worker.execute_model(scheduler_output)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/worker/tpu_worker.py", line 350, in execute_model
[rank0]: output = self.model_runner.execute_model(scheduler_output,
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/utils.py", line 359, in wrapper
[rank0]: result = func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/runner/tpu_runner.py", line 606, in execute_model
[rank0]: output = self._execute_model(scheduler_output,
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/runner/tpu_runner.py", line 798, in _execute_model
[rank0]: aux_hidden_states) = self.model_fn(
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/common/model_loader.py", line 290, in run_model
[rank0]: return model(*args)
[rank0]: ^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/jax/qwen3.py", line 352, in __call__
[rank0]: kv_caches, x = self.model(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/jax/qwen2.py", line 363, in __call__
[rank0]: kv_cache, x = layer(
[rank0]: ^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/jax/qwen2.py", line 277, in __call__
[rank0]: kv_cache, attn_output = self.self_attn(
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/models/jax/qwen3.py", line 172, in __call__
[rank0]: new_kv_cache, outputs = attention(
[rank0]: ^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/layers/common/attention_interface.py", line 403, in attention
[rank0]: output, kv_cache = sharded_ragged_paged_attention(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/layers/common/attention_interface.py", line 358, in sharded_ragged_paged_attention
[rank0]: return jax.shard_map(
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/layers/common/attention_interface.py", line 349, in _ragged_paged_attention
[rank0]: return func(
[rank0]: ^^^^^
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 1419, in ragged_paged_attention
[rank0]: static_validate_inputs(
[rank0]: File "/home/amahanka/maxtext_latest/lib/python3.12/site-packages/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 1261, in static_validate_inputs
[rank0]: raise ValueError(
[rank0]: ValueError: Expected kv_cache.dtype=dtype(bfloat16) to be equal to k.dtype=dtype(bfloat16) and v.dtype=dtype('float32').
Environment Information
I encounter this bug across several MaxText commits; the commit I am currently using is 5dfd6a4.
JAX version: 0.8.3
Python version: 3.12.13
TPU v4-8
OS: tpu-ubuntu2204-base
Additional Context
No response