Skip to content

Commit 6236cb6

Browse files
committed
docs: document JAX backend
Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent fa61d69 commit 6236cb6

File tree

11 files changed

+45
-10
lines changed

11 files changed

+45
-10
lines changed

doc/_static/jax.svg

Lines changed: 1 addition & 0 deletions
Loading

doc/backend.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ DeePMD-kit does not use the TensorFlow v2 API but uses the TensorFlow v1 API (`t
2323
[PyTorch](https://pytorch.org/) 2.0 or above is required.
2424
While `.pth` and `.pt` are the same in the PyTorch package, they have different meanings in the DeePMD-kit to distinguish the model and the checkpoint.
2525

26+
### JAX {{ jax_icon }}
27+
28+
- Model filename extension: `.xlo`
29+
- Checkpoint filename extension: `.jax`
30+
31+
[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
32+
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
33+
Currently, this backend is developed actively, and has no support for training and the C++ interface.
34+
2635
### DP {{ dpmodel_icon }}
2736

2837
:::{note}

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@
167167
myst_substitutions = {
168168
"tensorflow_icon": """![TensorFlow](/_static/tensorflow.svg){class=platform-icon}""",
169169
"pytorch_icon": """![PyTorch](/_static/pytorch.svg){class=platform-icon}""",
170+
"jax_icon": """![JAX](/_static/jax.svg){class=platform-icon}""",
170171
"dpmodel_icon": """![DP](/_static/logo_icon.svg){class=platform-icon}""",
171172
}
172173

doc/env.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ See [How to control the parallelism of a job](./troubleshooting/howtoset_num_nod
3131
- If ROCm is used, [ROCm environment variables](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#environment-variables) can be used to control ROCm devices.
3232
- {{ tensorflow_icon }} If TensorFlow is used, TensorFlow environment variables can be used.
3333
- {{ pytorch_icon }} If PyTorch is used, [PyTorch environment variables](https://pytorch.org/docs/stable/torch_environment_variables.html) can be used.
34+
- {{ jax_icon }} [`JAX_PLATFORMS`](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) and [`XLA_FLAGS`](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) are commonly used.
3435

3536
## Python interface only
3637

doc/install/install-from-source.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,21 @@ One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to instal
7878

7979
:::
8080

81+
:::{tab-item} JAX {{ jax_icon }}
82+
83+
To install [JAX AI Stack](https://github.com/jax-ml/jax-ai-stack), run
84+
85+
```sh
86+
pip install jax-ai-stack
87+
```
88+
89+
One can also install packages in JAX AI Stack manually.
90+
Follow [JAX documentation](https://jax.readthedocs.io/en/latest/installation.html) to install JAX built against different CUDA versions or without CUDA.
91+
92+
One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to install JAX from [conda-forge](https://conda-forge.org).
93+
94+
:::
95+
8196
::::
8297

8398
It is important that every time a new shell is started and one wants to use `DeePMD-kit`, the virtual environment should be activated by

doc/model/sel.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ dp --pt neighbor-stat -s data -r 6.0 -t O H
2424

2525
:::
2626

27+
:::{tab-item} JAX {{ jax_icon }}
28+
29+
```sh
30+
dp --jax neighbor-stat -s data -r 6.0 -t O H
31+
```
32+
33+
:::
34+
2735
::::
2836

2937
where `data` is the directory of data, `6.0` is the cutoff radius, and `O` and `H` is the type map. The program will give the `max_nbor_size`. For example, `max_nbor_size` of the water example is `[38, 72]`, meaning an atom may have 38 O neighbors and 72 H neighbors in the training data.

doc/model/train-energy.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
1+
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}
22

33
:::{note}
4-
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
4+
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
55
:::
66

77
In this section, we will take `$deepmd_source_dir/examples/water/se_e2_a/input.json` as an example of the input file.

doc/model/train-fitting-dos.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
1+
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}
22

33
:::{note}
4-
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
4+
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
55
:::
66

77
Here we present an API to DeepDOS model, which can be used to fit electronic density of state (DOS) (which is a vector).

doc/model/train-se-atten.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
1+
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}
22

33
:::{note}
4-
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
4+
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
55
:::
66

77
## DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation

doc/model/train-se-e2-a.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
1+
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}
22

33
:::{note}
4-
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
4+
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
55
:::
66

77
The notation of `se_e2_a` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from all information (both angular and radial) of atomic configurations. The `e2` stands for the embedding with two-atoms information. This descriptor was described in detail in [the DeepPot-SE paper](https://arxiv.org/abs/1805.09003).

0 commit comments

Comments
 (0)