`docs/install_maxtext.md` (72 changes: 49 additions & 23 deletions)
You can find the latest commit hashes in the JAX `build/` folder.

Next, run the `seed-env` CLI to generate the new requirements files. You will need to do this separately for the TPU and GPU environments. The generated files will be placed in a directory specified by `--output-dir`.
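The point of this step is the distinction between the two sets of files: base requirements carry loose constraints, while generated requirements pin exact versions so installs are reproducible. A toy sketch of that invariant (illustrative only, not part of `seed-env`; the function name is mine):

```python
def is_fully_pinned(requirement_line: str) -> bool:
    """True if a requirements line pins one exact version with `==`.

    Generated files should satisfy this for ordinary PyPI packages;
    base files typically carry loose or exclusion constraints instead.
    """
    line = requirement_line.split("#", 1)[0].strip()
    if not line or line.startswith("-r") or "@" in line:
        # Includes and direct-URL dependencies are carried over verbatim.
        return True
    return "==" in line

# Base-style constraint vs. generated-style pin:
assert not is_fully_pinned("cloud-tpu-diagnostics!=1.1.14")
assert is_fully_pinned("jax==0.4.35")
```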

> **Note:** The files currently in `src/dependencies/requirements/generated_requirements/` were generated using JAX build commit [d83b06508d669add43a8875ae7fd9e9fe7abf160](https://github.com/jax-ml/jax/commit/d83b06508d669add43a8875ae7fd9e9fe7abf160).

### TPU Pre-Training

If you have changed the pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-base-requirements.txt`, regenerate the pinned pre-training requirements in the `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

```bash
seed-env \
--local-requirements=src/dependencies/requirements/base_requirements/tpu-base-requirements.txt \
--host-name=MaxText \
--seed-commit=<jax-build-commit-hash> \
--python-version=3.12 \
--requirements-txt=tpu-requirements.txt \
--output-dir=generated_tpu_artifacts
```

After generating the new requirements, copy the generated file `generated_tpu_artifacts/tpu-requirements.txt` to `src/dependencies/requirements/generated_requirements/tpu-requirements.txt`.

### TPU Post-Training

If you have changed the post-training dependencies in `src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`, regenerate the pinned post-training requirements in the `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

```bash
seed-env \
--local-requirements=src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt \
--host-name=MaxText \
--seed-commit=<jax-build-commit-hash> \
--python-version=3.12 \
--requirements-txt=tpu-post-train-requirements.txt \
--output-dir=generated_tpu_post_train_artifacts
```

After generating the new requirements, copy the generated file `generated_tpu_post_train_artifacts/tpu-post-train-requirements.txt` to `src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt`.

### GPU Pre-Training

If you have changed the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/gpu-base-requirements.txt`, regenerate the pinned pre-training requirements in the `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

```bash
seed-env \
--local-requirements=src/dependencies/requirements/base_requirements/gpu-base-requirements.txt \
--host-name=MaxText \
--seed-commit=<jax-build-commit-hash> \
--python-version=3.12 \
--requirements-txt=cuda12-requirements.txt \
--hardware=cuda12 \
--output-dir=generated_gpu_artifacts
```

After generating the new requirements, copy the generated file `generated_gpu_artifacts/cuda12-requirements.txt` to `src/dependencies/requirements/generated_requirements/cuda12-requirements.txt`.
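The three copy steps (TPU pre-training, TPU post-training, GPU) follow the same pattern, so you may prefer a small helper. This is a hypothetical convenience script, not part of the MaxText tooling; the function name and the `.bak` backup behavior are my own:

```shell
# copy_generated SRC DST: install a regenerated requirements file into the
# tracked generated_requirements/ directory, keeping a .bak of the old pins.
copy_generated() {
  src="$1"
  dst="$2"
  if [ -f "$dst" ]; then
    cp "$dst" "$dst.bak"  # keep the previous pins for easy diffing
  fi
  mkdir -p "$(dirname "$dst")"
  cp "$src" "$dst"
}

# Example (run from the repository root):
# copy_generated generated_tpu_artifacts/tpu-requirements.txt \
#   src/dependencies/requirements/generated_requirements/tpu-requirements.txt
```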

## Step 4: Verify the New Dependencies

Finally, test that the new dependencies install correctly and that MaxText runs as expected.

1. **Create a clean environment:** It's best to start with a fresh Python virtual environment.

```bash
# Ensure uv is installed
pip install uv

# Create and activate the virtual environment
uv venv --python 3.12 --seed maxtext_venv
source maxtext_venv/bin/activate
```

2. **Install MaxText and dependencies:** Install the package in editable mode with the appropriate extras. Choose the command that matches your hardware:

**TPU Pre-Training**:

```bash
uv pip install -e .[tpu] --resolution=lowest
install_maxtext_tpu_github_deps
```

**TPU Post-Training**:

```bash
uv pip install -e .[tpu-post-train] --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
```

**GPU Pre-Training**:

```bash
uv pip install -e .[cuda12] --resolution=lowest
install_maxtext_cuda12_github_dep
```

3. **Verify the installation:** Run the MaxText tests to confirm that everything works with the newly installed dependencies and that there are no regressions.
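As part of verification, you can also confirm that the exclusion pins from the base requirements (such as `cloud-tpu-diagnostics!=1.1.14` and `tokamax!=0.1.0`) were respected by the resolver. A hypothetical stdlib-only check; `check_exclusions` and the `EXCLUDED` table are illustrative, not MaxText code:

```python
from importlib import metadata

# Known-bad releases excluded in the base requirements files.
EXCLUDED = {
    "cloud-tpu-diagnostics": {"1.1.14"},
    "tokamax": {"0.1.0"},
}

def check_exclusions(installed=None):
    """Return (package, version) pairs that landed on an excluded release.

    `installed` maps package name to version; when None, versions are read
    from the active environment via importlib.metadata.
    """
    bad = []
    for pkg, versions in EXCLUDED.items():
        try:
            ver = installed[pkg] if installed is not None else metadata.version(pkg)
        except (KeyError, metadata.PackageNotFoundError):
            continue  # not installed in this hardware variant; nothing to check
        if ver in versions:
            bad.append((pkg, ver))
    return bad
```

An empty return means no excluded release slipped through the resolution.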
@@ -1,8 +1,9 @@
absl-py
aqtp
array-record
chex
cloud-accelerator-diagnostics
cloud-tpu-diagnostics!=1.1.14
datasets
drjax
flax
@@ -40,9 +41,7 @@
tensorflow-datasets
tensorflow-text
tensorflow
tiktoken
tokamax!=0.1.0
transformers
uvloop
qwix
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
@@ -1,2 +1 @@
-r requirements.txt
google-tunix
@@ -0,0 +1,3 @@
-r requirements.txt
ipykernel
papermill