3 changes: 3 additions & 0 deletions deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py
Expand Up @@ -21,15 +21,18 @@
import re
from concurrent.futures import ThreadPoolExecutor
import math

Collaborator: revert extra spaces in imports
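For context, the blank lines this PR adds appear to follow PEP 8's import grouping: standard library, then third-party, then local. A sketch of the convention (not this file's final form):

```
import math  # standard library
import re
from concurrent.futures import ThreadPoolExecutor

import jax  # third-party
import numpy as np

from deepseek_r1_jax.third_party import modeling_deepseek as deepseek  # local
```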

from etils import epath

import jax
from tqdm import tqdm
from jax import numpy as jnp
from jax.sharding import SingleDeviceSharding, NamedSharding, PartitionSpec as P
import numpy as np

Collaborator: revert extra spaces in imports

import torch
import torch.utils.dlpack

Collaborator: revert extra spaces in imports

from safetensors.torch import load_file

from .third_party import modeling_deepseek as deepseek
Expand Down
1 change: 1 addition & 0 deletions deepseek_r1_jax/deepseek_r1_jax/decode_ragged_dot.py
Expand Up @@ -21,6 +21,7 @@
from jax import random
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

Collaborator: revert extra spaces in imports

from tqdm import tqdm


Expand Down
3 changes: 2 additions & 1 deletion deepseek_r1_jax/deepseek_r1_jax/model.py
Expand Up @@ -20,11 +20,12 @@
from functools import partial
from typing import Callable
import tempfile
from etils import epath
import gzip
import json
from pathlib import Path

from etils import epath

Collaborator: revert extra spaces in imports

import jax
import jax.numpy as jnp
from jax import random
Expand Down
4 changes: 3 additions & 1 deletion deepseek_r1_jax/scripts/get_params_metadata_hf.py
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

import json
import safetensors
from pathlib import Path

import safetensors

from tqdm import tqdm

if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions llama3/main.py
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

import dataclasses
import os

from etils import epath
import json
from pprint import pprint
Expand Down Expand Up @@ -41,9 +43,9 @@ def encode_input(tokenizer, texts: list[str], model_name: str, pad_id: int = 0):

if __name__ == "__main__":
jax.distributed.initialize()
quant = True
quant = os.environ.get("QUANT") in (None, 't', 'T', 'True', 'true', 'TRUE', '1')

ckpt_path = epath.Path("~/bucket/DeepSeek-R1-Distill-Llama-3.1-70B-Instruct").expanduser()
ckpt_path = epath.Path("~").expanduser() / "bucket" / "DeepSeek-R1-Distill-Llama-3.1-70B-Instruct"
if quant:
ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant"
tokenizer = l3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json")
Expand Down
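Note on the QUANT flag above (the same check appears in the llama4 and qwen3 mains): os.environ.get() only ever returns a string or None, so the membership test must compare against '1' and never the int 1. A helper like this sketch (hypothetical, not part of this PR) would centralize the parsing:

```
import os

def env_flag(name: str, default: bool = True) -> bool:
    """Parse a boolean environment variable; unset falls back to `default`."""
    value = os.environ.get(name)  # always a str or None, never an int
    if value is None:
        return default
    return value.strip().lower() in ("1", "t", "true", "y", "yes")

quant = env_flag("QUANT")
```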
13 changes: 7 additions & 6 deletions llama3/pyproject.toml
Expand Up @@ -10,15 +10,16 @@ requires-python = ">=3.10"
license = { text = "Apache-2.0" }

dependencies = [
"datasets",
"etils",
"gcsfs",
"huggingface-hub",
"jax",
"torch",
"transformers", # for the model config and the tokenizer
"tqdm",
"numpy",
"orbax-checkpoint",
"datasets",
"gcsfs",
"etils",
"torch",
"tqdm",
"transformers", # for the model config and the tokenizer
]

# we don't need CUDA torch
Expand Down
5 changes: 3 additions & 2 deletions llama3/scripts/convert_weights.py
Expand Up @@ -5,6 +5,7 @@
from argparse import ArgumentParser
import dataclasses
import shutil
import os.path

def main(model_path: str | Path, ckpt_path: str | Path):
try:
Expand All @@ -21,9 +22,9 @@ def main(model_path: str | Path, ckpt_path: str | Path):
from tqdm import tqdm

model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser()
files = list(model_path.glob("**/*safetensors"))
files = list(model_path.glob(os.path.join("**", "*safetensors")))
assert len(files) > 1
config_files = list(model_path.glob("**/config.json"))
config_files = list(model_path.glob(os.path.join("**", "config.json")))
assert len(config_files) == 1, "Must have only one `config.json` file in the model path"
config = AutoConfig.from_pretrained(config_files[0])
cfg = l3jax.llama_to_jax_config(config)
Expand Down
3 changes: 2 additions & 1 deletion llama3/scripts/download_model.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import os.path
from argparse import ArgumentParser
from pathlib import Path

Expand Down Expand Up @@ -27,7 +28,7 @@ def main(model_id: str, dest_root_path: str | Path):
parser.add_argument(
"--dest-root-path",
required=True,
default="~/",
default=os.path.join(os.path.expanduser("~"), ""),
help="Destination root directory, the model will be saved into its own directory.",
)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion llama4/.gitignore → llama4_jax/.gitignore
Expand Up @@ -11,4 +11,4 @@ __pycache__/
build/**

.venv
.vscode
.vscode
8 changes: 4 additions & 4 deletions llama4/README.md → llama4_jax/README.md
Expand Up @@ -11,8 +11,8 @@ This is a pure JAX implementation of Llama 4 inference, including a checkpoint
converter for the weights. It currently runs on TPU. Support for GPU is
in progress.

The entire model is defined in [model.py](llama4_jax/model.py) and invoked
via [main.py](main.py). Among other things, the model code demonstrates:
The entire model is defined in [model.py](llama4_jax/model.py) and invoked
Collaborator: Can we keep main.py separate? I think an explicit script might make it clearer that it's just a starting point for a larger program rather than default behavior for the llama4 module.

Author: @rdyro It's your call. But it is very nonstandard Python; the __main__.py properly integrates… seeing itself as within a Python module, being relocatable, and being installable.

One other idea is to create a hierarchy:

.
├── jax_examples_cli
├── deepseek_r1_jax
└── llama4_jax

All packages are installable. There are no main.pys. Outside of jax_examples_cli there are no __main__.pys. The command-line interface finds installed packages and/or packages in a specific location (e.g., os.getcwd() or a JAX_LLM_EXAMPLES env var) that tells it where to find packages compatible with jax_examples_cli/__main__.py.

Your main.pys are already very similar, so it wouldn't be hard to hoist them up. Usage would be: jax_examples_cli --model <deepseek-r1-jax | llama4_jax> --ckpt_path …
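A minimal sketch of that dispatcher; jax_examples_cli, JAX_LLM_EXAMPLES, and the run() entry point are illustrative assumptions, not existing code:

```
# jax_examples_cli/__main__.py (hypothetical)
import argparse
import importlib
import os
import sys

def main():
    parser = argparse.ArgumentParser(prog="jax_examples_cli")
    parser.add_argument("--model", required=True,
                        choices=["deepseek_r1_jax", "llama4_jax"])
    parser.add_argument("--ckpt_path", required=True)
    args, rest = parser.parse_known_args()

    extra_dir = os.environ.get("JAX_LLM_EXAMPLES")
    if extra_dir is not None:
        sys.path.insert(0, extra_dir)  # also search a user-specified location

    # assumes each example package exposes a run(ckpt_path, argv) entry point
    module = importlib.import_module(args.model)
    module.run(args.ckpt_path, rest)

if __name__ == "__main__":
    main()
```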

via `python3 -m llama4_jax`. Among other things, the model code demonstrates:
* an MLA attention implementation;
* expert and tensor-parallelism via JAX's
[`shard_map`](https://docs.jax.dev/en/latest/sharded-computation.html#manual-parallelism-with-shard-map)
Expand Down Expand Up @@ -47,12 +47,12 @@ the full model. We've tested on v5e-64.

Run on all hosts in the TPU cluster:
```
$ python3 main.py
$ python3 -m llama4_jax
```
e.g. for Cloud TPU:
```
$ gcloud compute tpus tpu-vm ssh {TPU_NAME} --worker=all \
--command="cd ~/llama4_jax && python3 main.py"
--command="cd ~/llama4_jax && python3 -m llama4_jax"
```

Responses:
Expand Down
Empty file.
7 changes: 4 additions & 3 deletions llama4/main.py → llama4_jax/llama4_jax/__main__.py
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

import dataclasses
import os

from etils import epath
import json
from pprint import pformat

import jax
from jax import numpy as jnp
Expand All @@ -40,9 +41,9 @@ def encode_input(tokenizer, texts, pad_id: int = 0):

if __name__ == "__main__":
jax.distributed.initialize()
quant = True
quant = os.environ.get("QUANT") in (None, 't', 'T', 'True', 'true', 'TRUE', '1')

ckpt_path = epath.Path("~/bucket/Llama-4-Scout-Instruct").expanduser()
ckpt_path = epath.Path("~").expanduser() / "bucket" / "Llama-4-Scout-Instruct"
if quant:
ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant"
tokenizer = l4jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json")
Expand Down
llama4/chkpt_utils.py → llama4_jax/llama4_jax/chkpt_utils.py
Expand Up @@ -18,15 +18,16 @@
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import dataclasses
from typing import Any

import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec as P

Collaborator: revert extra spaces in imports

import torch

from tqdm import tqdm

from . import model as l4jax
from llama4_jax import model as l4jax


def quantize_model(ckpt_path: Path, quant_ckpt_path: Path):
Expand Down
llama4/model.py → llama4_jax/llama4_jax/model.py
Expand Up @@ -30,6 +30,7 @@
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P, use_mesh
from jax.experimental.shard import auto_axes, reshard

Collaborator: revert extra spaces in imports

from etils import epath

from . import ragged_attention
Expand Down
llama4/ragged_attention.py → llama4_jax/llama4_jax/ragged_attention.py
Expand Up @@ -10,6 +10,7 @@
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

Collaborator: revert extra spaces in imports

import numpy as np

NUM_LANES = 128
Expand Down
16 changes: 9 additions & 7 deletions llama4/pyproject.toml → llama4_jax/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "llama4_jax"
version = "0.1.0"
description = ""
description = "Pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights."
Collaborator: nice!

authors = [
{ name = "Robert Dyro" },
]
Expand All @@ -10,15 +10,17 @@ requires-python = ">=3.10"
license = { text = "Apache-2.0" }

dependencies = [
"datasets",
Collaborator: please undo import sorting

"etils",
"gcsfs",
"huggingface-hub",
"jax",
"torch",
"transformers", # for the model config and the tokenizer
"tqdm",
"numpy",
"orbax-checkpoint",
"datasets",
"gcsfs",
"etils",
"torch",
"tqdm",
"transformers", # for the model config and the tokenizer
"tune_jax",
]

# we don't need CUDA torch
Expand Down
Empty file added llama4_jax/scripts/__init__.py
Empty file.
llama4/scripts/convert_weights.py → llama4_jax/scripts/convert_weights.py
@@ -1,22 +1,15 @@
#!/usr/bin/env python3

import sys
from pathlib import Path
from argparse import ArgumentParser
import dataclasses
import shutil

from llama4_jax import model as l4jax
from llama4_jax import chkpt_utils as utils

def main(model_path: str | Path, ckpt_path: str | Path):
Collaborator: importing at top level makes for a very slow help message; import under main.
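A sketch of that suggestion (hypothetical, not this PR's code): keep argparse under the entry-point guard and defer the heavy imports into main() so `--help` responds quickly.

```
from pathlib import Path

def main(model_path: str | Path, ckpt_path: str | Path):
    # heavy imports run only once conversion actually starts
    from llama4_jax import model as l4jax
    from llama4_jax import chkpt_utils as utils
    from transformers import AutoConfig
    ...

if __name__ == "__main__":
    from argparse import ArgumentParser  # cheap import; --help stays fast
    parser = ArgumentParser()
    parser.add_argument("--model-path", required=True)
    parser.add_argument("--ckpt-path", required=True)
    args = parser.parse_args()
    main(args.model_path, args.ckpt_path)
```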

try:
from llama4_jax import model as l4jax
from llama4_jax import chkpt_utils as utils
except ImportError:
sys.path.append(str(Path(__file__).parents[1].absolute()))

from llama4_jax import model as l4jax
from llama4_jax import chkpt_utils as utils

def main(model_path: str | Path, ckpt_path: str | Path):
from transformers import AutoConfig
from safetensors import safe_open
from tqdm import tqdm
Expand All @@ -43,7 +36,7 @@ def main(model_path: str | Path, ckpt_path: str | Path):

additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
for additional_file in additional_files:
full_paths = list(model_path.glob(f"**/{additional_file}"))
full_paths = list(model_path.glob(f"**/{additional_file}"))
if len(full_paths) != 1:
print(f"Found more than 1 file for {additional_file}")
if len(full_paths) == 0:
Expand Down
File renamed without changes.
File renamed without changes.
6 changes: 4 additions & 2 deletions qwen3/main.py
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

import dataclasses
import os

Collaborator: revert extra spaces in imports

from etils import epath
import json

Expand All @@ -38,9 +40,9 @@ def encode_input(tokenizer, texts, pad_id: int = 0):

if __name__ == "__main__":
#jax.distributed.initialize() # if you want to run multi-host
quant = True
quant = os.environ.get("QUANT") in (None, 't', 'T', 'True', 'true', 'TRUE', '1')

ckpt_path = epath.Path("~/bucket/qwen3_jax/Qwen3-30B-A3B").expanduser()
ckpt_path = epath.Path("~").expanduser() / "bucket" / "qwen3_jax" / "Qwen3-30B-A3B"
if quant:
ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant"
tokenizer = q3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json")
Expand Down
13 changes: 7 additions & 6 deletions qwen3/pyproject.toml
Expand Up @@ -10,15 +10,16 @@ requires-python = ">=3.10"
license = { text = "Apache-2.0" }

dependencies = [
"datasets",
"etils",
"gcsfs",
"huggingface-hub",
"jax",
"torch",
"transformers", # for the model config and the tokenizer
"tqdm",
"numpy",
"orbax-checkpoint",
"datasets",
"gcsfs",
"etils",
"torch",
"tqdm",
"transformers", # for the model config and the tokenizer
]

# we don't need CUDA torch
Expand Down
12 changes: 6 additions & 6 deletions qwen3/scripts/convert_weights.py
Expand Up @@ -2,7 +2,7 @@

import sys
from pathlib import Path
from pprint import pprint
import os.path
from argparse import ArgumentParser
import dataclasses
import shutil
Expand All @@ -22,9 +22,9 @@ def main(model_path: str | Path, ckpt_path: str | Path):
from tqdm import tqdm

model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser()
files = list(model_path.glob("**/*safetensors"))
files = list(model_path.glob(os.path.join("**", "*safetensors")))
Collaborator: don't mix pathlib and os.path

assert len(files) > 1
config_files = list(model_path.glob("**/config.json"))
config_files = list(model_path.glob(os.path.join("**", "config.json")))
Collaborator: don't mix pathlib and os.path
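A pure-pathlib version of that suggestion (a sketch; the path is illustrative). Path.glob() accepts "/"-separated patterns on every platform, so os.path.join is not needed to build them:

```
from pathlib import Path

model_path = Path("~/Qwen3-30B-A3B").expanduser()  # illustrative path
files = list(model_path.glob("**/*safetensors"))
config_files = list(model_path.glob("**/config.json"))
```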

assert len(config_files) == 1, "Must have only one `config.json` file in the model path"
config = AutoConfig.from_pretrained(config_files[0])
cfg = q3jax.hf_to_jax_config(config)
Expand All @@ -43,7 +43,7 @@ def main(model_path: str | Path, ckpt_path: str | Path):

additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
for additional_file in additional_files:
full_paths = list(model_path.glob(f"**/{additional_file}"))
full_paths = list(model_path.glob(os.path.join("**", additional_file)))
if len(full_paths) != 1:
print(f"Found more than 1 file for {additional_file}")
if len(full_paths) == 0:
Expand All @@ -55,11 +55,11 @@ def main(model_path: str | Path, ckpt_path: str | Path):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--source-path", default="~/Qwen3-30B-A3B", required=True, help="HF model directory path"
"--source-path", default=Path("~").expanduser() / "Qwen3-30B-A3B", required=True, help="HF model directory path"
)
parser.add_argument(
"--dest-path",
default="~/qwen3_jax/Qwen3-30B-A3B",
default=Path("~").expanduser() / "qwen3_jax" / "Qwen3-30B-A3B",
required=True,
help="JAX model model directory (to be created).",
)
Expand Down
2 changes: 1 addition & 1 deletion qwen3/scripts/quantize_model.py
Expand Up @@ -21,7 +21,7 @@ def main(path: str | Path, suffix: str):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--path", default="~/Qwen3-30B-A3B", required=True, help="Existing JAX model checkpoint path"
"--path", default=Path("~").expanduser() / "Qwen3-30B-A3B", required=True, help="Existing JAX model checkpoint path"
)
parser.add_argument(
"--suffix",
Expand Down