
Commit 698b08d

Authored by njzjz, coderabbitai[bot], and Your Name
feat(jax): SavedModel C++ interface (including DPA-2 supports) (#4307)
Including both the nlist and no-nlist interfaces. Limitation: a SavedModel created on one device cannot be run on another; for example, a CUDA model cannot be run on the CPU. The model is generated using #4336.

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Added support for the JAX backend, including specific model and checkpoint file formats.
  - Introduced a new shell script for model conversion to enhance usability.
  - Updated installation documentation to clarify JAX support and requirements.
  - New section in documentation detailing limitations of the JAX backend with LAMMPS.
- **Bug Fixes**
  - Enhanced error handling for model initialization and backend compatibility.
- **Documentation**
  - Updated backend documentation to include JAX details and limitations.
  - Improved clarity in installation instructions for both TensorFlow and JAX.
- **Tests**
  - Added comprehensive unit tests for JAX integration with the Deep Potential class.
  - Expanded test coverage for LAMMPS integration with DeepMD.
- **Chores**
  - Updated CMake configurations and workflow files for improved testing and dependency management.

---

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Your Name <[email protected]>
1 parent 85e5e20 commit 698b08d

File tree

24 files changed: +12703 −19 lines

.github/workflows/test_cc.yml

Lines changed: 7 additions & 7 deletions
```diff
@@ -27,7 +27,13 @@ jobs:
           mpi: mpich
       - uses: lukka/get-cmake@latest
       - run: python -m pip install uv
-      - run: source/install/uv_with_retry.sh pip install --system tensorflow
+      - name: Install Python dependencies
+        run: |
+          source/install/uv_with_retry.sh pip install --system tensorflow-cpu
+          export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
+          source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py
+      - name: Convert models
+        run: source/tests/infer/convert-models.sh
       - name: Download libtorch
         run: |
           wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip -O libtorch.zip
@@ -47,12 +53,6 @@ jobs:
           CMAKE_GENERATOR: Ninja
           CXXFLAGS: ${{ matrix.check_memleak && '-fsanitize=leak' || '' }}
       # test lammps
-      - run: |
-          export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
-          source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp] mpi4py
-        env:
-          DP_BUILD_TESTING: 1
-        if: ${{ !matrix.check_memleak }}
       - run: pytest --cov=deepmd source/lmp/tests
         env:
           OMP_NUM_THREADS: 1
```

.github/workflows/test_cuda.yml

Lines changed: 5 additions & 2 deletions
```diff
@@ -19,7 +19,7 @@ jobs:
     runs-on: nvidia
     # https://github.com/deepmodeling/deepmd-kit/pull/2884#issuecomment-1744216845
     container:
-      image: nvidia/cuda:12.3.1-devel-ubuntu22.04
+      image: nvidia/cuda:12.6.2-cudnn-devel-ubuntu22.04
       options: --gpus all
     if: github.repository_owner == 'deepmodeling' && (github.event_name == 'pull_request' && github.event.label && github.event.label.name == 'Test CUDA' || github.event_name == 'workflow_dispatch' || github.event_name == 'merge_group')
     steps:
@@ -63,12 +63,15 @@ jobs:
           CUDA_VISIBLE_DEVICES: 0
           # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
           XLA_PYTHON_CLIENT_PREALLOCATE: false
+      - name: Convert models
+        run: source/tests/infer/convert-models.sh
      - name: Download libtorch
         run: |
           wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip
           unzip libtorch.zip
       - run: |
           export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch
+          export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
           source/install/test_cc_local.sh
         env:
           OMP_NUM_THREADS: 1
@@ -79,7 +82,7 @@ jobs:
           DP_VARIANT: cuda
           DP_USE_MPICH2: 1
       - run: |
-          export LD_LIBRARY_PATH=$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$CUDA_PATH/lib64:$LD_LIBRARY_PATH
+          export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$LD_LIBRARY_PATH
           export PATH=$GITHUB_WORKSPACE/dp_test/bin:$PATH
           python -m pytest -s source/lmp/tests || (cat log.lammps && exit 1)
           python -m pytest source/ipi/tests
```

doc/backend.md

Lines changed: 3 additions & 1 deletion
```diff
@@ -31,7 +31,9 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different
 [JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
 Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
 `.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow.
-Currently, this backend is developed actively, and has no support for training and the C++ interface.
+Only the `.savedmodel` format supports C++ inference, which requires the TensorFlow C++ interface.
+The model is device-specific, so a model generated on a GPU cannot be run on CPUs.
+Currently, this backend is under active development and does not yet support training.
 
 ### DP {{ dpmodel_icon }}
 
```
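For orientation, below is a minimal sketch of what C++ inference with a `.savedmodel` looks like through the header-only `deepmd.hpp` wrapper, using the no-nlist interface. The model path, atom data, and box are placeholders, and the `DeepPot` constructor and `compute` overload are assumed to match the declarations in `deepmd.hpp`; this is a sketch, not a verified example from the commit.

```cpp
// Sketch: C++ inference with a JAX SavedModel via the deepmd.hpp wrapper.
// "frozen_model.savedmodel" and the water geometry below are placeholders.
#include <iostream>
#include <vector>

#include "deepmd.hpp"

int main() {
  // Loading a .savedmodel needs the TensorFlow C++ library at runtime, and the
  // model must have been generated for the same kind of device it runs on.
  deepmd::hpp::DeepPot dp("frozen_model.savedmodel");

  // One water molecule: O at the origin, two H atoms nearby (Angstrom).
  std::vector<double> coord = {0., 0., 0., 0.96, 0., 0., -0.24, 0.93, 0.};
  std::vector<int> atype = {0, 1, 1};  // indices into the model's type map
  std::vector<double> cell = {20., 0., 0., 0., 20., 0., 0., 0., 20.};

  double energy;
  std::vector<double> force, virial;
  // No-nlist interface: the backend builds the neighbor list internally.
  dp.compute(energy, force, virial, coord, atype, cell);

  std::cout << "energy: " << energy << std::endl;
  return 0;
}
```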
doc/install/install-from-source.md

Lines changed: 6 additions & 4 deletions
```diff
@@ -297,7 +297,9 @@ If one does not need to use DeePMD-kit with LAMMPS or i-PI, then the python inte
 
 ::::{tab-set}
 
-:::{tab-item} TensorFlow {{ tensorflow_icon }}
+:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}
+
+The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library.
 
 Since TensorFlow 2.12, TensorFlow C++ library (`libtensorflow_cc`) is packaged inside the Python library. Thus, you can skip building TensorFlow C++ library manually. If that does not work for you, you can still build it manually.
 
@@ -338,7 +340,7 @@ We recommend using [conda packages](https://docs.deepmodeling.org/faq/conda.html
 
 ::::{tab-set}
 
-:::{tab-item} TensorFlow {{ tensorflow_icon }}
+:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}
 
 I assume you have activated the TensorFlow Python environment and want to install DeePMD-kit into path `$deepmd_root`, then execute CMake
 
@@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value
 
 **Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`
 
-{{ tensorflow_icon }} Whether building the TensorFlow backend.
+{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.
 
 :::
 
@@ -391,7 +393,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value
 
 **Type**: `PATH`
 
-{{ tensorflow_icon }} The Path to TensorFlow's C++ interface.
+{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface.
 
 :::
 
```
doc/model/dpa2.md

Lines changed: 10 additions & 0 deletions
````diff
@@ -18,6 +18,16 @@ If one runs LAMMPS with MPI, the customized OP library for the C++ interface sho
 If one runs LAMMPS with MPI and CUDA devices, it is recommended to compile the customized OP library for the C++ interface with a [CUDA-Aware MPI](https://developer.nvidia.com/mpi-solutions-gpus) library and CUDA,
 otherwise the communication between GPU cards falls back to the slower CPU implementation.
 
+## Limitations of the JAX backend with LAMMPS {{ jax_icon }}
+
+When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command.
+
+```lammps
+atom_modify map yes
+```
+
+See the example `examples/water/lmp/jax_dpa2.lammps`.
+
 ## Data format
 
 DPA-2 supports both the [standard data format](../data/system.md) and the [mixed type data format](../data/system.md#mixed-type).
````
Binary file changed: 221 KB (contents not shown)

examples/water/lmp/jax_dpa2.lammps

Lines changed: 31 additions & 0 deletions
```diff
@@ -0,0 +1,31 @@
+
+
+# bulk water
+
+units           metal
+boundary        p p p
+atom_style      atomic
+# Below line is required when using DPA-2 with the JAX backend
+atom_modify     map yes
+
+neighbor        2.0 bin
+neigh_modify    every 10 delay 0 check no
+
+read_data       water.lmp
+mass            1 16
+mass            2 2
+
+# See https://deepmd.rtfd.io/lammps/ for usage
+pair_style      deepmd frozen_model.savedmodel
+# If atom names (O H in this example) are not set in the pair_coeff command, the type_map defined by the training parameter will be used by default.
+pair_coeff      * * O H
+
+velocity        all create 330.0 23456789
+
+fix             1 all nvt temp 330.0 330.0 0.5
+timestep        0.0005
+thermo_style    custom step pe ke etotal temp press vol
+thermo          100
+dump            1 all custom 100 water.dump id type x y z
+
+run             1000
```

source/api_c/include/c_api.h

Lines changed: 13 additions & 3 deletions
```diff
@@ -12,7 +12,7 @@ extern "C" {
 /** C API version. Bumped whenever the API is changed.
  * @since API version 22
  */
-#define DP_C_API_VERSION 24
+#define DP_C_API_VERSION 25
 
 /**
  * @brief Neighbor list.
@@ -31,7 +31,7 @@ extern DP_Nlist* DP_NewNlist(int inum_,
                              int* ilist_,
                              int* numneigh_,
                              int** firstneigh_);
-/*
+/**
  * @brief Create a new neighbor list with communication capabilities.
 * @details This function extends DP_NewNlist by adding support for parallel
 * communication, allowing the neighbor list to be used in distributed
@@ -68,7 +68,7 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
                                   int* recvproc,
                                   void* world);
 
-/*
+/**
  * @brief Set mask for a neighbor list.
  *
  * @param nl Neighbor list.
@@ -78,6 +78,16 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
 **/
 extern void DP_NlistSetMask(DP_Nlist* nl, int mask);
 
+/**
+ * @brief Set mapping for a neighbor list.
+ *
+ * @param nl Neighbor list.
+ * @param mapping mapping from all atoms to real atoms, in size nall.
+ * @since API version 25
+ *
+ **/
+extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);
+
 /**
  * @brief Delete a neighbor list.
  *
```

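The new `DP_NlistSetMapping` entry point attaches a mapping from all atoms to real atoms onto an existing neighbor list. A rough sketch of how it slots into the C API is shown below; the two-atom neighbor data are placeholders standing in for what a caller such as LAMMPS would normally provide.

```cpp
// Sketch: wiring a mapping into a DP_Nlist via the C API (placeholder data).
#include "c_api.h"

void attach_mapping_example() {
  int ilist[2] = {0, 1};     // local atom indices
  int numneigh[2] = {1, 1};  // number of neighbors per local atom
  int neigh0[1] = {1}, neigh1[1] = {0};
  int* firstneigh[2] = {neigh0, neigh1};

  DP_Nlist* nl = DP_NewNlist(2, ilist, numneigh, firstneigh);

  // Mapping from all atoms (including ghosts) to real atoms, of size nall;
  // with no ghost atoms it is simply the identity.
  int mapping[2] = {0, 1};
  DP_NlistSetMapping(nl, mapping);  // new in C API version 25

  DP_DeleteNlist(nl);
}
```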
source/api_c/include/deepmd.hpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -863,6 +863,11 @@ struct InputNlist {
    * @brief Set mask for this neighbor list.
    */
   void set_mask(int mask) { DP_NlistSetMask(nl, mask); };
+  /**
+   * @brief Set mapping for this neighbor list.
+   * @param mapping mapping from all atoms to real atoms, in size nall.
+   */
+  void set_mapping(int *mapping) { DP_NlistSetMapping(nl, mapping); };
 };
 
 /**
```

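On the header-only wrapper side, the same capability is exposed as `InputNlist::set_mapping`. The following is a hedged sketch of the neighbor-list inference path, assuming the `InputNlist` constructor and the `compute`-with-nlist overload declared in `deepmd.hpp`; all arrays and sizes are caller-provided placeholders.

```cpp
// Sketch: InputNlist::set_mapping before a neighbor-list based compute call.
#include <vector>

#include "deepmd.hpp"

void nlist_compute_example(deepmd::hpp::DeepPot& dp,
                           std::vector<double>& coord,  // size 3 * nall
                           std::vector<int>& atype,     // size nall
                           std::vector<double>& cell,
                           int nloc, int nall,
                           int* ilist, int* numneigh, int** firstneigh,
                           std::vector<int>& mapping) { // size nall
  deepmd::hpp::InputNlist lmp_list(nloc, ilist, numneigh, firstneigh);
  // Mapping from all atoms to real atoms, in size nall (per the set_mapping docs).
  lmp_list.set_mapping(mapping.data());

  double energy;
  std::vector<double> force, virial;
  const int nghost = nall - nloc;
  // nlist interface: the caller supplies the neighbor list and ghost count.
  dp.compute(energy, force, virial, coord, atype, cell,
             nghost, lmp_list, /*ago=*/0);
}
```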
source/api_c/src/c_api.cc

Lines changed: 3 additions & 0 deletions
```diff
@@ -43,6 +43,9 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
   return new_nl;
 }
 void DP_NlistSetMask(DP_Nlist* nl, int mask) { nl->nl.set_mask(mask); }
+void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+  nl->nl.set_mapping(mapping);
+}
 void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }
 
 // DP Base Model
```
