I am reviewing the example notebooks to make sure we are able to run them automatically. I have observed a few failures and blockers that I'll list here. All of these were run in the free tier, using the TPU v2-8 runtime.
### 1. grpo_demo

Fails at the training step with `RESOURCE_EXHAUSTED`:

```python
with mesh:
  grpo_trainer.train(dataset)
```
```
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
/tmp/ipython-input-3813433836.py in <cell line: 0>()
      1 with mesh:
----> 2     grpo_trainer.train(dataset)

4 frames
[... skipping hidden 5 frame]
/usr/local/lib/python3.12/dist-packages/jax/_src/interpreters/pxla.py in __call__(self, *args)
   1360         self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1361       else:
-> 1362         results = self.xla_executable.execute_sharded(input_bufs)
   1363
   1364     if dispatch.needs_check_special():

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 20.25M. That was not possible. There are 3.88M free.; (0x0x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
```
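The allocator message suggests HBM is almost fully committed before this allocation even happens (each v2 core only has 8 GB of HBM). For anyone triaging this, here is a quick way to inspect per-device HBM headroom right before the training step; it's a generic sketch using only public JAX APIs, not tunix internals:

```python
import jax

# Print per-device HBM usage right before grpo_trainer.train() to see
# how close each core is to its limit. memory_stats() can return None
# on some backends (see item 2 below), so guard against that.
for d in jax.devices():
    stats = d.memory_stats()
    if stats is None:
        print(f"{d}: memory stats unavailable")
        continue
    used_gb = stats["bytes_in_use"] / 1e9
    limit_gb = stats["bytes_limit"] / 1e9
    print(f"{d}: {used_gb:.2f} / {limit_gb:.2f} GB HBM in use")
```

If usage is already near the limit before training starts, the notebook's GRPO config may simply be sized for a larger runtime than the free-tier v2-8.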
### 2. logit_distillation

The `show_hbm_usage` utility function fails:

```python
import functools
import humanize
from tunix.rl import utils

show_hbm_usage = utils.show_hbm_usage
show_hbm_usage()
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipython-input-2863494701.py in <cell line: 0>()
      4
      5 show_hbm_usage = utils.show_hbm_usage
----> 6 show_hbm_usage()

1 frames
/usr/local/lib/python3.12/dist-packages/tunix/rl/utils.py in jax_hbm_usage_gb(devices)
    131   for d in devices:
    132     stats = d.memory_stats()
--> 133     used = stats["bytes_in_use"]
    134     limit = stats["bytes_limit"]
    135     hbm_used.append((used, limit))

TypeError: 'NoneType' object is not subscriptable
```
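The failing line assumes `Device.memory_stats()` always returns a dict, but on this runtime it returns `None`. A minimal defensive rewrite of the failing loop would skip such devices; this is a sketch of one possible fix, not necessarily what the tunix patch should look like:

```python
import jax

def jax_hbm_usage_gb(devices=None):
    """Collect (bytes_in_use, bytes_limit) per device, skipping any
    device whose memory_stats() returns None instead of a dict."""
    if devices is None:
        devices = jax.devices()
    hbm_used = []
    for d in devices:
        stats = d.memory_stats()
        if stats is None:  # some runtimes don't expose memory stats
            continue
        hbm_used.append((stats["bytes_in_use"], stats["bytes_limit"]))
    return hbm_used
```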
### 3. qlora_demo

Needs a wandb API key. Since we are looking to run this automatically, how should the login be handled? (Two possible approaches are sketched after the log below.)
```
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:
```
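Assuming the automated runs don't need the metrics uploaded, two standard ways to avoid the interactive prompt (both `WANDB_API_KEY` and `WANDB_MODE` are documented wandb environment variables):

```python
import os

# Option 1: supply the key non-interactively, e.g. from a CI secret store.
# os.environ["WANDB_API_KEY"] = "<key injected by the runner>"

# Option 2: turn wandb off entirely so no key is needed.
os.environ["WANDB_MODE"] = "disabled"  # or "offline" to keep logs locally
```

Either one has to run before the notebook creates its first wandb run.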
### 4. qwen3_example
I've just tested this with a TPU v2-8 runtime and the notebook runs fine after authentication with Kaggle, despite a couple of user warnings:
#### Warning 1
```python
from flax import nnx
import kagglehub

from tunix.models.qwen3 import model
from tunix.models.qwen3 import params

MODEL_CP_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")
config = (
    model.ModelConfig.qwen3_0_6b()
)  # pick the corresponding config based on the model version
qwen3 = params.create_model_from_safe_tensors(MODEL_CP_PATH, config)
nnx.display(qwen3)
```
```
---------------------------------------------------------------------------
/usr/local/lib/python3.12/dist-packages/treescope/renderers.py:314: UserWarning: Ignoring error while formatting value of type <class 'flax.nnx.helpers.List'> with <function handle_via_treescope_repr_method at 0x7e06cef6c040>:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/treescope/renderers.py", line 290, in _render_subtree
    maybe_result = handler(node=node, path=path, subtree_renderer=rec)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/treescope/_internal/handlers/custom_type_handlers.py", line 65, in handle_via_treescope_repr_method
    return treescope_repr_method(path, subtree_renderer)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/flax/nnx/pytreelib.py", line 702, in __treescope_repr__
    if name.startswith('_'):
       ^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'startswith'
...
```
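This one is only a `UserWarning` (treescope falls back to a default repr and `nnx.display` still renders), so for automated runs it could probably just be filtered until the upstream flax.nnx / treescope fix lands. A sketch:

```python
import warnings

# Hide the treescope fallback warning; the display output is unaffected.
warnings.filterwarnings(
    "ignore",
    message="Ignoring error while formatting value",
    category=UserWarning,
)
```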
#### Warning 2

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_CP_PATH)
```
```
---------------------------------------------------------------------------
/usr/local/lib/python3.12/dist-packages/torch_xla/experimental/gru.py:113: SyntaxWarning: invalid escape sequence '\_'
  * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
/usr/local/lib/python3.12/dist-packages/torch_xla/__init__.py:258: UserWarning: `tensorflow` can conflict with `torch-xla`. Prefer `tensorflow-cpu` when using PyTorch/XLA. To silence this warning, `pip uninstall -y tensorflow && pip install tensorflow-cpu`. If you are in a notebook environment such as Colab or Kaggle, restart your notebook runtime afterwards.
  warnings.warn(
```