diff --git a/deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py b/deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py index 701e8fa..0c96b0e 100644 --- a/deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py +++ b/deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py @@ -21,6 +21,7 @@ import re from concurrent.futures import ThreadPoolExecutor import math + from etils import epath import jax @@ -28,8 +29,10 @@ from jax import numpy as jnp from jax.sharding import SingleDeviceSharding, NamedSharding, PartitionSpec as P import numpy as np + import torch import torch.utils.dlpack + from safetensors.torch import load_file from .third_party import modeling_deepseek as deepseek diff --git a/deepseek_r1_jax/deepseek_r1_jax/decode_ragged_dot.py b/deepseek_r1_jax/deepseek_r1_jax/decode_ragged_dot.py index 837ff75..895be1b 100644 --- a/deepseek_r1_jax/deepseek_r1_jax/decode_ragged_dot.py +++ b/deepseek_r1_jax/deepseek_r1_jax/decode_ragged_dot.py @@ -21,6 +21,7 @@ from jax import random from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu + from tqdm import tqdm diff --git a/deepseek_r1_jax/deepseek_r1_jax/model.py b/deepseek_r1_jax/deepseek_r1_jax/model.py index 699eb16..1015378 100644 --- a/deepseek_r1_jax/deepseek_r1_jax/model.py +++ b/deepseek_r1_jax/deepseek_r1_jax/model.py @@ -20,11 +20,12 @@ from functools import partial from typing import Callable import tempfile -from etils import epath import gzip import json from pathlib import Path +from etils import epath + import jax import jax.numpy as jnp from jax import random diff --git a/deepseek_r1_jax/scripts/get_params_metadata_hf.py b/deepseek_r1_jax/scripts/get_params_metadata_hf.py index d2ff0a2..5cf0244 100644 --- a/deepseek_r1_jax/scripts/get_params_metadata_hf.py +++ b/deepseek_r1_jax/scripts/get_params_metadata_hf.py @@ -13,8 +13,10 @@ # limitations under the License. import json -import safetensors from pathlib import Path + +import safetensors + from tqdm import tqdm if __name__ == "__main__": diff --git a/llama3/main.py b/llama3/main.py index 9b2cf4c..fdd9d15 100644 --- a/llama3/main.py +++ b/llama3/main.py @@ -13,6 +13,8 @@ # limitations under the License. import dataclasses +import os + from etils import epath import json from pprint import pprint @@ -41,9 +43,9 @@ def encode_input(tokenizer, texts: list[str], model_name: str, pad_id: int = 0): if __name__ == "__main__": jax.distributed.initialize() - quant = True + quant = os.environ.get("QUANT") in (None, 't', 'T', 'True', 'true', 'TRUE', 1) - ckpt_path = epath.Path("~/bucket/DeepSeek-R1-Distill-Llama-3.1-70B-Instruct").expanduser() + ckpt_path = epath.Path("~").expanduser() / "bucket" / "DeepSeek-R1-Distill-Llama-3.1-70B-Instruct" if quant: ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant" tokenizer = l3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json") diff --git a/llama3/pyproject.toml b/llama3/pyproject.toml index 100b781..3d97322 100644 --- a/llama3/pyproject.toml +++ b/llama3/pyproject.toml @@ -10,15 +10,16 @@ requires-python = ">=3.10" license = { text = "Apache-2.0" } dependencies = [ + "datasets", + "etils", + "gcsfs", + "huggingface-hub", "jax", - "torch", - "transformers", # for the model config and the tokenizer - "tqdm", "numpy", "orbax-checkpoint", - "datasets", - "gcsfs", - "etils", + "torch", + "tqdm", + "transformers", # for the model config and the tokenizer ] # we don't need CUDA torch diff --git a/llama3/scripts/convert_weights.py b/llama3/scripts/convert_weights.py index 2912c22..5c71e59 100644 --- a/llama3/scripts/convert_weights.py +++ b/llama3/scripts/convert_weights.py @@ -5,6 +5,7 @@ from argparse import ArgumentParser import dataclasses import shutil +import os.path def main(model_path: str | Path, ckpt_path: str | Path): try: @@ -21,9 +22,9 @@ def main(model_path: str | Path, ckpt_path: str | Path): from tqdm import tqdm model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser() - files = list(model_path.glob("**/*safetensors")) + files = list(model_path.glob(os.path.join("**", "*safetensors"))) assert len(files) > 1 - config_files = list(model_path.glob("**/config.json")) + config_files = list(model_path.glob(os.path.join("**", "config.json"))) assert len(config_files) == 1, "Must have only one `config.json` file in the model path" config = AutoConfig.from_pretrained(config_files[0]) cfg = l3jax.llama_to_jax_config(config) diff --git a/llama3/scripts/download_model.py b/llama3/scripts/download_model.py index b609d63..eb6f41e 100644 --- a/llama3/scripts/download_model.py +++ b/llama3/scripts/download_model.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import os.path from argparse import ArgumentParser from pathlib import Path @@ -27,7 +28,7 @@ def main(model_id: str, dest_root_path: str | Path): parser.add_argument( "--dest-root-path", required=True, - default="~/", + default=os.path.join(os.path.expanduser("~"), ""), help="Destination root directory, the model will be saved into its own directory.", ) args = parser.parse_args() diff --git a/llama4/.gitignore b/llama4_jax/.gitignore similarity index 94% rename from llama4/.gitignore rename to llama4_jax/.gitignore index 3afd13a..630d63b 100644 --- a/llama4/.gitignore +++ b/llama4_jax/.gitignore @@ -11,4 +11,4 @@ __pycache__/ build/** .venv -.vscode \ No newline at end of file +.vscode diff --git a/llama4/README.md b/llama4_jax/README.md similarity index 98% rename from llama4/README.md rename to llama4_jax/README.md index e117ab1..4667b15 100644 --- a/llama4/README.md +++ b/llama4_jax/README.md @@ -11,8 +11,8 @@ This is a pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights. It currently runs on TPU. Support for GPU is in-progress. -The entire model is defined in [model.py](llama4_jax/model.py) and invoked -via [main.py](main.py). Among other things, the model code demonstrates: +The entire model is defined in [__main__.py](llama4_jax/__main__.py) and invoked +via `python3 -m llam4_jax`. Among other things, the model code demonstrates: * an MLA attention implementation; * expert and tensor-parallelism via JAX's [`shard_map`](https://docs.jax.dev/en/latest/sharded-computation.html#manual-parallelism-with-shard-map) @@ -47,12 +47,12 @@ the full model. We've tested on v5e-64. Run on all hosts in the TPU cluster: ``` -$ python3 main.py +$ python3 -m llam4_jax ``` e.g. for Cloud TPU: ``` $ gcloud compute tpus tpu-vm ssh {TPU_NAME} --worker=all \ - --command="cd ~/llama4_jax && python3 main.py" + --command="cd ~/llama4_jax && python3 -m llam4_jax" ``` Responses: diff --git a/llama4_jax/llama4_jax/__init__.py b/llama4_jax/llama4_jax/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llama4/main.py b/llama4_jax/llama4_jax/__main__.py similarity index 94% rename from llama4/main.py rename to llama4_jax/llama4_jax/__main__.py index aa61675..009a0f1 100644 --- a/llama4/main.py +++ b/llama4_jax/llama4_jax/__main__.py @@ -13,9 +13,10 @@ # limitations under the License. import dataclasses +import os + from etils import epath import json -from pprint import pformat import jax from jax import numpy as jnp @@ -40,9 +41,9 @@ def encode_input(tokenizer, texts, pad_id: int = 0): if __name__ == "__main__": jax.distributed.initialize() - quant = True + quant = os.environ.get("QUANT") in (None, 't', 'T', 'True', 'true', 'TRUE', 1) - ckpt_path = epath.Path("~/bucket/Llama-4-Scout-Instruct").expanduser() + ckpt_path = epath.Path("~").expanduser() / "bucket" / "Llama-4-Scout-Instruct" if quant: ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant" tokenizer = l4jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json") diff --git a/llama4/llama4_jax/chkpt_utils.py b/llama4_jax/llama4_jax/chkpt_utils.py similarity index 99% rename from llama4/llama4_jax/chkpt_utils.py rename to llama4_jax/llama4_jax/chkpt_utils.py index 3370ff7..4507af7 100644 --- a/llama4/llama4_jax/chkpt_utils.py +++ b/llama4_jax/llama4_jax/chkpt_utils.py @@ -18,15 +18,16 @@ from concurrent.futures import ThreadPoolExecutor from pathlib import Path import dataclasses -from typing import Any import jax from jax import numpy as jnp from jax.sharding import PartitionSpec as P + import torch + from tqdm import tqdm -from . import model as l4jax +from llama4_jax import model as l4jax def quantize_model(ckpt_path: Path, quant_ckpt_path: Path): diff --git a/llama4/llama4_jax/decode_ragged_dot.py b/llama4_jax/llama4_jax/decode_ragged_dot.py similarity index 100% rename from llama4/llama4_jax/decode_ragged_dot.py rename to llama4_jax/llama4_jax/decode_ragged_dot.py diff --git a/llama4/llama4_jax/model.py b/llama4_jax/llama4_jax/model.py similarity index 99% rename from llama4/llama4_jax/model.py rename to llama4_jax/llama4_jax/model.py index 2a4b8f8..27ca52b 100644 --- a/llama4/llama4_jax/model.py +++ b/llama4_jax/llama4_jax/model.py @@ -30,6 +30,7 @@ from jax.experimental.shard_map import shard_map from jax.sharding import PartitionSpec as P, use_mesh from jax.experimental.shard import auto_axes, reshard + from etils import epath from . import ragged_attention diff --git a/llama4/llama4_jax/ragged_attention.py b/llama4_jax/llama4_jax/ragged_attention.py similarity index 99% rename from llama4/llama4_jax/ragged_attention.py rename to llama4_jax/llama4_jax/ragged_attention.py index 347c346..b2b0556 100644 --- a/llama4/llama4_jax/ragged_attention.py +++ b/llama4_jax/llama4_jax/ragged_attention.py @@ -10,6 +10,7 @@ from jax.experimental.pallas import tpu as pltpu from jax.experimental.shard_map import shard_map from jax.sharding import Mesh, PartitionSpec as P, NamedSharding + import numpy as np NUM_LANES = 128 diff --git a/llama4/pyproject.toml b/llama4_jax/pyproject.toml similarity index 84% rename from llama4/pyproject.toml rename to llama4_jax/pyproject.toml index 75d4167..28b4526 100644 --- a/llama4/pyproject.toml +++ b/llama4_jax/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "llama4_jax" version = "0.1.0" -description = "" +description = "Pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights." authors = [ { name = "Robert Dyro" }, ] @@ -10,15 +10,17 @@ requires-python = ">=3.10" license = { text = "Apache-2.0" } dependencies = [ + "datasets", + "etils", + "gcsfs", + "huggingface-hub", "jax", - "torch", - "transformers", # for the model config and the tokenizer - "tqdm", "numpy", "orbax-checkpoint", - "datasets", - "gcsfs", - "etils", + "torch", + "tqdm", + "transformers", # for the model config and the tokenizer + "tune_jax", ] # we don't need CUDA torch diff --git a/llama4_jax/scripts/__init__.py b/llama4_jax/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llama4/scripts/convert_weights.py b/llama4_jax/scripts/convert_weights.py similarity index 85% rename from llama4/scripts/convert_weights.py rename to llama4_jax/scripts/convert_weights.py index a5db741..2ea8c94 100644 --- a/llama4/scripts/convert_weights.py +++ b/llama4_jax/scripts/convert_weights.py @@ -1,22 +1,15 @@ #!/usr/bin/env python3 -import sys from pathlib import Path from argparse import ArgumentParser import dataclasses import shutil +from llama4_jax import model as l4jax +from llama4_jax import chkpt_utils as utils -def main(model_path: str | Path, ckpt_path: str | Path): - try: - from llama4_jax import model as l4jax - from llama4_jax import chkpt_utils as utils - except ImportError: - sys.path.append(str(Path(__file__).parents[1].absolute())) - - from llama4_jax import model as l4jax - from llama4_jax import chkpt_utils as utils +def main(model_path: str | Path, ckpt_path: str | Path): from transformers import AutoConfig from safetensors import safe_open from tqdm import tqdm @@ -43,7 +36,7 @@ def main(model_path: str | Path, ckpt_path: str | Path): additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"] for additional_file in additional_files: - full_paths = list(model_path.glob(f"**/{additional_file}")) + full_paths = list(model_path.glob(Path("~").expanduser() / "DeepSeek-R1-Distill-Llama-70B")) if len(full_paths) != 1: print(f"Found more than 1 file for {additional_file}") if len(full_paths) == 0: diff --git a/llama4/scripts/download_model.py b/llama4_jax/scripts/download_model.py similarity index 100% rename from llama4/scripts/download_model.py rename to llama4_jax/scripts/download_model.py diff --git a/llama4/scripts/quantize_model.py b/llama4_jax/scripts/quantize_model.py similarity index 100% rename from llama4/scripts/quantize_model.py rename to llama4_jax/scripts/quantize_model.py diff --git a/qwen3/main.py b/qwen3/main.py index b768016..485ce14 100644 --- a/qwen3/main.py +++ b/qwen3/main.py @@ -13,6 +13,8 @@ # limitations under the License. import dataclasses +import os + from etils import epath import json @@ -38,9 +40,9 @@ def encode_input(tokenizer, texts, pad_id: int = 0): if __name__ == "__main__": #jax.distributed.initialize() # if you want to run multi-host - quant = True + quant = os.environ.get("QUANT") in (None, 't', 'T', 'True', 'true', 'TRUE', 1) - ckpt_path = epath.Path("~/bucket/qwen3_jax/Qwen3-30B-A3B").expanduser() + ckpt_path = epath.Path("~").expanduser() / "bucket" / "qwen3_jax" / "Qwen3-30B-A3B" if quant: ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant" tokenizer = q3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json") diff --git a/qwen3/pyproject.toml b/qwen3/pyproject.toml index d7fec1f..86abad0 100644 --- a/qwen3/pyproject.toml +++ b/qwen3/pyproject.toml @@ -10,15 +10,16 @@ requires-python = ">=3.10" license = { text = "Apache-2.0" } dependencies = [ + "datasets", + "etils", + "gcsfs", + "huggingface-hub", "jax", - "torch", - "transformers", # for the model config and the tokenizer - "tqdm", "numpy", "orbax-checkpoint", - "datasets", - "gcsfs", - "etils", + "torch", + "tqdm", + "transformers", # for the model config and the tokenizer ] # we don't need CUDA torch diff --git a/qwen3/scripts/convert_weights.py b/qwen3/scripts/convert_weights.py index a72b87c..2f03a3b 100644 --- a/qwen3/scripts/convert_weights.py +++ b/qwen3/scripts/convert_weights.py @@ -2,7 +2,7 @@ import sys from pathlib import Path -from pprint import pprint +import os.path from argparse import ArgumentParser import dataclasses import shutil @@ -22,9 +22,9 @@ def main(model_path: str | Path, ckpt_path: str | Path): from tqdm import tqdm model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser() - files = list(model_path.glob("**/*safetensors")) + files = list(model_path.glob(os.path.join("**", "*safetensors"))) assert len(files) > 1 - config_files = list(model_path.glob("**/config.json")) + config_files = list(model_path.glob(os.path.join("**", "config.json"))) assert len(config_files) == 1, "Must have only one `config.json` file in the model path" config = AutoConfig.from_pretrained(config_files[0]) cfg = q3jax.hf_to_jax_config(config) @@ -43,7 +43,7 @@ def main(model_path: str | Path, ckpt_path: str | Path): additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"] for additional_file in additional_files: - full_paths = list(model_path.glob(f"**/{additional_file}")) + full_paths = list(model_path.glob(os.path.join("**", additional_file))) if len(full_paths) != 1: print(f"Found more than 1 file for {additional_file}") if len(full_paths) == 0: @@ -55,11 +55,11 @@ def main(model_path: str | Path, ckpt_path: str | Path): if __name__ == "__main__": parser = ArgumentParser() parser.add_argument( - "--source-path", default="~/Qwen3-30B-A3B", required=True, help="HF model directory path" + "--source-path", default=Path("~").expanduser() / "Qwen3-30B-A3B", required=True, help="HF model directory path" ) parser.add_argument( "--dest-path", - default="~/qwen3_jax/Qwen3-30B-A3B", + default=Path("~").expanduser() / "qwen3_jax" / "Qwen3-30B-A3B", required=True, help="JAX model model directory (to be created).", ) diff --git a/qwen3/scripts/quantize_model.py b/qwen3/scripts/quantize_model.py index 4b0a230..1594276 100644 --- a/qwen3/scripts/quantize_model.py +++ b/qwen3/scripts/quantize_model.py @@ -21,7 +21,7 @@ def main(path: str | Path, suffix: str): if __name__ == "__main__": parser = ArgumentParser() parser.add_argument( - "--path", default="~/Qwen3-30B-A3B", required=True, help="Existing JAX model checkpoint path" + "--path", default=Path("~").expanduser() / "Qwen3-30B-A3B", required=True, help="Existing JAX model checkpoint path" ) parser.add_argument( "--suffix",