Example notebooks: failures and warnings #309

@melissawm

Description

I am reviewing the example notebooks to make sure we are able to run them automatically. I have observed a few failures and blockers that I'll list here. All of these were run in the free tier, using the TPU v2-8 runtime.

1. grpo_demo

Fails on training step with RESOURCE_EXHAUSTED:

with mesh:
  grpo_trainer.train(dataset)
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
[/tmp/ipython-input-3813433836.py](https://localhost:8080/#) in <cell line: 0>()
      1 with mesh:
----> 2   grpo_trainer.train(dataset)

4 frames
    [... skipping hidden 5 frame]

[/usr/local/lib/python3.12/dist-packages/jax/_src/interpreters/pxla.py](https://localhost:8080/#) in __call__(self, *args)
   1360         self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1361       else:
-> 1362         results = self.xla_executable.execute_sharded(input_bufs)
   1363 
   1364       if dispatch.needs_check_special():

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 20.25M. That was not possible. There are 3.88M free.; (0x0x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
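For the automated runner, it could be useful to recognize these OOM failures and report the shortfall instead of just surfacing the raw traceback. A minimal sketch of such a helper (`parse_oom` is hypothetical, not part of tunix or JAX; the message format matches the traceback above):

```python
import re

def parse_oom(message):
    """Extract requested/free MB from an XLA RESOURCE_EXHAUSTED message."""
    m = re.search(r"Attempting to allocate ([\d.]+)M\..*?([\d.]+)M free", message)
    if m is None:
        return None  # not an allocation-failure message
    requested, free = float(m.group(1)), float(m.group(2))
    return {
        "requested_mb": requested,
        "free_mb": free,
        "shortfall_mb": round(requested - free, 2),
    }

msg = (
    "RESOURCE_EXHAUSTED: Error allocating device buffer: "
    "Attempting to allocate 20.25M. That was not possible. "
    "There are 3.88M free."
)
print(parse_oom(msg))  # shortfall_mb: 16.37
```

A runner could use this to skip or retry the notebook with a smaller configuration rather than hard-failing.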

2. logit_distillation

Fails in a utility function:

import functools
import humanize
from tunix.rl import utils

show_hbm_usage = utils.show_hbm_usage
show_hbm_usage()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[/tmp/ipython-input-2863494701.py](https://localhost:8080/#) in <cell line: 0>()
      4 
      5 show_hbm_usage = utils.show_hbm_usage
----> 6 show_hbm_usage()

1 frames
[/usr/local/lib/python3.12/dist-packages/tunix/rl/utils.py](https://localhost:8080/#) in jax_hbm_usage_gb(devices)
    131   for d in devices:
    132     stats = d.memory_stats()
--> 133     used = stats["bytes_in_use"]
    134     limit = stats["bytes_limit"]
    135     hbm_used.append((used, limit))

TypeError: 'NoneType' object is not subscriptable
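The crash happens because `d.memory_stats()` returned `None`: on some backends the device does not expose memory statistics, and the utility indexes into the result unconditionally. A defensive sketch of the fix (using a stand-in `FakeDevice` class so the example is self-contained; the real objects are `jax.Device` instances):

```python
def hbm_usage_gb(devices):
    """Return (used_gb, limit_gb) per device, skipping devices without stats."""
    usage = []
    for d in devices:
        stats = d.memory_stats()
        if stats is None:  # backend does not expose memory statistics
            continue
        used = stats["bytes_in_use"] / 1e9
        limit = stats["bytes_limit"] / 1e9
        usage.append((used, limit))
    return usage

class FakeDevice:
    """Stand-in for jax.Device, for illustration only."""
    def __init__(self, stats):
        self._stats = stats
    def memory_stats(self):
        return self._stats

devices = [
    FakeDevice({"bytes_in_use": 2_000_000_000, "bytes_limit": 8_000_000_000}),
    FakeDevice(None),  # reproduces the failure mode seen above
]
print(hbm_usage_gb(devices))  # [(2.0, 8.0)] -- the None device is skipped
```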

3. qlora_demo

Requires a wandb API key. Since we want to run these notebooks automatically, how can the login be automated?

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:
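One way to avoid the interactive prompt in automated runs would be to configure wandb through environment variables before it is imported: `WANDB_MODE=disabled` turns all wandb calls into no-ops, and `WANDB_MODE=offline` records runs locally without contacting wandb.ai. A minimal sketch:

```python
import os

# Must be set before wandb is first imported in the notebook.
os.environ["WANDB_MODE"] = "disabled"  # or "offline" to keep local run files

# For real logging in CI, the key could instead be injected as a secret:
# os.environ["WANDB_API_KEY"] = "<ci-secret>"
```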

4. qwen3_example

I've just tested this with a TPU v2-8 runtime and the notebook runs fine after authentication with Kaggle, despite a couple of user warnings:
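On the Kaggle authentication step: for unattended runs, kagglehub can fall back to the `KAGGLE_USERNAME` / `KAGGLE_KEY` environment variables when no cached credentials exist, so CI could export those as secrets instead of calling the interactive login. A sketch with placeholder values (real ones would come from the CI secret store):

```python
import os

# setdefault so locally cached or already-exported credentials win.
os.environ.setdefault("KAGGLE_USERNAME", "ci-bot")
os.environ.setdefault("KAGGLE_KEY", "<kaggle-api-token>")
```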

Warning 1

from flax import nnx
import kagglehub
from tunix.models.qwen3 import model
from tunix.models.qwen3 import params

MODEL_CP_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")

config = (
    model.ModelConfig.qwen3_0_6b()
)  # pick corresponding config based on model version
qwen3 = params.create_model_from_safe_tensors(MODEL_CP_PATH, config)
nnx.display(qwen3)
---------------------------------------------------------------------------
/usr/local/lib/python3.12/dist-packages/treescope/renderers.py:314: UserWarning: Ignoring error while formatting value of type <class 'flax.nnx.helpers.List'> with <function handle_via_treescope_repr_method at 0x7e06cef6c040>:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/treescope/renderers.py", line 290, in _render_subtree
    maybe_result = handler(node=node, path=path, subtree_renderer=rec)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/treescope/_internal/handlers/custom_type_handlers.py", line 65, in handle_via_treescope_repr_method
    return treescope_repr_method(path, subtree_renderer)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/flax/nnx/pytreelib.py", line 702, in __treescope_repr__
    if name.startswith('_'):
       ^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'startswith'

...
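Minimal reproduction of the root cause: `__treescope_repr__` in flax calls `name.startswith('_')` on attribute names, but for list-like modules such as `nnx.helpers.List` the keys are integer indices, not strings. A type guard (a hypothetical fix, not the actual flax code) sidesteps the AttributeError:

```python
def looks_private(name):
    """Guarded version of the check in flax's __treescope_repr__."""
    return isinstance(name, str) and name.startswith("_")

print(looks_private("_hidden"))  # True: private string attribute
print(looks_private(0))          # False: integer index, no crash

# Without the guard, the original check raises, as in the warning above:
try:
    (0).startswith("_")
except AttributeError as e:
    print(type(e).__name__)  # AttributeError
```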

Warning 2

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_CP_PATH)
---------------------------------------------------------------------------
/usr/local/lib/python3.12/dist-packages/torch_xla/experimental/gru.py:113: SyntaxWarning: invalid escape sequence '\_'
  * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
/usr/local/lib/python3.12/dist-packages/torch_xla/__init__.py:258: UserWarning: `tensorflow` can conflict with `torch-xla`. Prefer `tensorflow-cpu` when using PyTorch/XLA. To silence this warning, `pip uninstall -y tensorflow && pip install tensorflow-cpu`. If you are in a notebook environment such as Colab or Kaggle, restart your notebook runtime afterwards.
  warnings.warn(
