-
Notifications
You must be signed in to change notification settings - Fork 23
[llama4 -> llama4_jax] Refactor to be a proper installable Python package #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6fd621f
2286f87
d5c936a
3812791
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,15 +21,18 @@ | |
import re | ||
from concurrent.futures import ThreadPoolExecutor | ||
import math | ||
|
||
from etils import epath | ||
|
||
import jax | ||
from tqdm import tqdm | ||
from jax import numpy as jnp | ||
from jax.sharding import SingleDeviceSharding, NamedSharding, PartitionSpec as P | ||
import numpy as np | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert extra spaces in imports |
||
import torch | ||
import torch.utils.dlpack | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert extra spaces in imports |
||
from safetensors.torch import load_file | ||
|
||
from .third_party import modeling_deepseek as deepseek | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ | |
from jax import random | ||
from jax.experimental import pallas as pl | ||
from jax.experimental.pallas import tpu as pltpu | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert extra spaces in imports |
||
from tqdm import tqdm | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,11 +20,12 @@ | |
from functools import partial | ||
from typing import Callable | ||
import tempfile | ||
from etils import epath | ||
import gzip | ||
import json | ||
from pathlib import Path | ||
|
||
from etils import epath | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert extra spaces in imports |
||
import jax | ||
import jax.numpy as jnp | ||
from jax import random | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,4 +11,4 @@ __pycache__/ | |
build/** | ||
|
||
.venv | ||
.vscode | ||
.vscode |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,8 +11,8 @@ This is a pure JAX implementation of Llama 4 inference, including a checkpoint | |
converter for the weights. It currently runs on TPU. Support for GPU is | ||
in-progress. | ||
|
||
The entire model is defined in [model.py](llama4_jax/model.py) and invoked | ||
via [main.py](main.py). Among other things, the model code demonstrates: | ||
The entire model is defined in [__main__.py](llama4_jax/__main__.py) and invoked | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rdyro It's your call. But it is very nonstandard Python, the One other idea is to create a hierarchy:
All packages are installable. There are no Already your |
||
via `python3 -m llam4_jax`. Among other things, the model code demonstrates: | ||
* an MLA attention implementation; | ||
* expert and tensor-parallelism via JAX's | ||
[`shard_map`](https://docs.jax.dev/en/latest/sharded-computation.html#manual-parallelism-with-shard-map) | ||
|
@@ -47,12 +47,12 @@ the full model. We've tested on v5e-64. | |
|
||
Run on all hosts in the TPU cluster: | ||
``` | ||
$ python3 main.py | ||
$ python3 -m llam4_jax | ||
``` | ||
e.g. for Cloud TPU: | ||
``` | ||
$ gcloud compute tpus tpu-vm ssh {TPU_NAME} --worker=all \ | ||
--command="cd ~/llama4_jax && python3 main.py" | ||
--command="cd ~/llama4_jax && python3 -m llam4_jax" | ||
``` | ||
|
||
Responses: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,15 +18,16 @@ | |
from concurrent.futures import ThreadPoolExecutor | ||
from pathlib import Path | ||
import dataclasses | ||
from typing import Any | ||
|
||
import jax | ||
from jax import numpy as jnp | ||
from jax.sharding import PartitionSpec as P | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert extra spaces in imports |
||
import torch | ||
|
||
from tqdm import tqdm | ||
|
||
from . import model as l4jax | ||
from llama4_jax import model as l4jax | ||
|
||
|
||
def quantize_model(ckpt_path: Path, quant_ckpt_path: Path): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
from jax.experimental.shard_map import shard_map | ||
from jax.sharding import PartitionSpec as P, use_mesh | ||
from jax.experimental.shard import auto_axes, reshard | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert extra spaces in imports |
||
from etils import epath | ||
|
||
from . import ragged_attention | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
from jax.experimental.pallas import tpu as pltpu | ||
from jax.experimental.shard_map import shard_map | ||
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert extra spaces in imports |
||
import numpy as np | ||
|
||
NUM_LANES = 128 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
[project] | ||
name = "llama4_jax" | ||
version = "0.1.0" | ||
description = "" | ||
description = "Pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice! |
||
authors = [ | ||
{ name = "Robert Dyro" }, | ||
] | ||
|
@@ -10,15 +10,17 @@ requires-python = ">=3.10" | |
license = { text = "Apache-2.0" } | ||
|
||
dependencies = [ | ||
"datasets", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please undo import sorting |
||
"etils", | ||
"gcsfs", | ||
"huggingface-hub", | ||
"jax", | ||
"torch", | ||
"transformers", # for the model config and the tokenizer | ||
"tqdm", | ||
"numpy", | ||
"orbax-checkpoint", | ||
"datasets", | ||
"gcsfs", | ||
"etils", | ||
"torch", | ||
"tqdm", | ||
"transformers", # for the model config and the tokenizer | ||
"tune_jax", | ||
] | ||
|
||
# we don't need CUDA torch | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,15 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import sys | ||
from pathlib import Path | ||
from argparse import ArgumentParser | ||
import dataclasses | ||
import shutil | ||
|
||
from llama4_jax import model as l4jax | ||
from llama4_jax import chkpt_utils as utils | ||
|
||
def main(model_path: str | Path, ckpt_path: str | Path): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. importing at top level makes for a very slow help message, import under main |
||
try: | ||
from llama4_jax import model as l4jax | ||
from llama4_jax import chkpt_utils as utils | ||
except ImportError: | ||
sys.path.append(str(Path(__file__).parents[1].absolute())) | ||
|
||
from llama4_jax import model as l4jax | ||
from llama4_jax import chkpt_utils as utils | ||
|
||
def main(model_path: str | Path, ckpt_path: str | Path): | ||
from transformers import AutoConfig | ||
from safetensors import safe_open | ||
from tqdm import tqdm | ||
|
@@ -43,7 +36,7 @@ def main(model_path: str | Path, ckpt_path: str | Path): | |
|
||
additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"] | ||
for additional_file in additional_files: | ||
full_paths = list(model_path.glob(f"**/{additional_file}")) | ||
full_paths = list(model_path.glob(Path("~").expanduser() / "DeepSeek-R1-Distill-Llama-70B")) | ||
if len(full_paths) != 1: | ||
print(f"Found more than 1 file for {additional_file}") | ||
if len(full_paths) == 0: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,8 @@ | |
# limitations under the License. | ||
|
||
import dataclasses | ||
import os | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert extra spaces in imports |
||
from etils import epath | ||
import json | ||
|
||
|
@@ -38,9 +40,9 @@ def encode_input(tokenizer, texts, pad_id: int = 0): | |
|
||
if __name__ == "__main__": | ||
#jax.distributed.initialize() # if you want to run multi-host | ||
quant = True | ||
quant = os.environ.get("QUANT") in (None, 't', 'T', 'True', 'true', 'TRUE', 1) | ||
|
||
ckpt_path = epath.Path("~/bucket/qwen3_jax/Qwen3-30B-A3B").expanduser() | ||
ckpt_path = epath.Path("~").expanduser() / "bucket" / "qwen3_jax" / "Qwen3-30B-A3B" | ||
if quant: | ||
ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant" | ||
tokenizer = q3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
import sys | ||
from pathlib import Path | ||
from pprint import pprint | ||
import os.path | ||
from argparse import ArgumentParser | ||
import dataclasses | ||
import shutil | ||
|
@@ -22,9 +22,9 @@ def main(model_path: str | Path, ckpt_path: str | Path): | |
from tqdm import tqdm | ||
|
||
model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser() | ||
files = list(model_path.glob("**/*safetensors")) | ||
files = list(model_path.glob(os.path.join("**", "*safetensors"))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't mix pathlib and os.path |
||
assert len(files) > 1 | ||
config_files = list(model_path.glob("**/config.json")) | ||
config_files = list(model_path.glob(os.path.join("**", "config.json"))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't mix pathlib and os.path |
||
assert len(config_files) == 1, "Must have only one `config.json` file in the model path" | ||
config = AutoConfig.from_pretrained(config_files[0]) | ||
cfg = q3jax.hf_to_jax_config(config) | ||
|
@@ -43,7 +43,7 @@ def main(model_path: str | Path, ckpt_path: str | Path): | |
|
||
additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"] | ||
for additional_file in additional_files: | ||
full_paths = list(model_path.glob(f"**/{additional_file}")) | ||
full_paths = list(model_path.glob(os.path.join("**", additional_file))) | ||
if len(full_paths) != 1: | ||
print(f"Found more than 1 file for {additional_file}") | ||
if len(full_paths) == 0: | ||
|
@@ -55,11 +55,11 @@ def main(model_path: str | Path, ckpt_path: str | Path): | |
if __name__ == "__main__": | ||
parser = ArgumentParser() | ||
parser.add_argument( | ||
"--source-path", default="~/Qwen3-30B-A3B", required=True, help="HF model directory path" | ||
"--source-path", default=Path("~").expanduser() / "Qwen3-30B-A3B", required=True, help="HF model directory path" | ||
) | ||
parser.add_argument( | ||
"--dest-path", | ||
default="~/qwen3_jax/Qwen3-30B-A3B", | ||
default=Path("~").expanduser() / "qwen3_jax" / "Qwen3-30B-A3B", | ||
required=True, | ||
help="JAX model model directory (to be created).", | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert extra spaces in imports