Skip to content

Commit dcbf607

Browse files
authored
feat(jax): build nlist in the SavedModel & fix nopbc for StableHLO and SavedModel (#4318)
Per our discussion, use TF to build the neighbor list in the SavedModel format. Also, fix a bug when the number of ghost atoms is zero. The polymorphic_shape needs to be larger than 1, and `nghost == 0` triggered the error. Previously, I also tried `nall` or `nghost - 1`, but neither of them worked. Finally, I export two different functions... So now four functions are stored in the model: (calculate the virial or not) × (nghost is zero or not). Tests for nopbc are added. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Enhanced model initialization with additional parameters for improved functionality. - Introduced functions for neighbor list management and region transformations in molecular simulations. - Added new methods for handling atomic virial calculations in model predictions. - New functions for transforming model outputs to accommodate local and ghost atoms. - **Bug Fixes** - Improved error handling in model serialization and evaluation processes. - **Tests** - Added comprehensive unit tests for new functionalities, ensuring consistent behavior across different scenarios, including tests for neighbor list construction and region transformations. - **Chores** - Updated testing workflow for better organization and efficiency. - Modified dependency management and linting configurations in `pyproject.toml`. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 023bb9c commit dcbf607

File tree

17 files changed

+997
-37
lines changed

17 files changed

+997
-37
lines changed

.github/workflows/test_python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
env:
5959
NUM_WORKERS: 0
6060
- name: Test TF2 eager mode
61-
run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0
61+
run: pytest --cov=deepmd --cov-append source/tests/consistent/io/test_io.py source/jax2tf_tests --durations=0
6262
env:
6363
NUM_WORKERS: 0
6464
DP_TEST_TF2_ONLY: 1

deepmd/jax/infer/deep_eval.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ def __init__(
9797
stablehlo_atomic_virial=model_data["@variables"][
9898
"stablehlo_atomic_virial"
9999
].tobytes(),
100+
stablehlo_no_ghost=model_data["@variables"][
101+
"stablehlo_no_ghost"
102+
].tobytes(),
103+
stablehlo_atomic_virial_no_ghost=model_data["@variables"][
104+
"stablehlo_atomic_virial_no_ghost"
105+
].tobytes(),
100106
model_def_script=model_data["model_def_script"],
101107
**model_data["constants"],
102108
)

deepmd/jax/jax2tf/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import tensorflow as tf
3+
import tensorflow.experimental.numpy as tnp
34

45
if not tf.executing_eagerly():
56
# TF disallow temporary eager execution
@@ -9,3 +10,5 @@
910
"If you are converting a model between different backends, "
1011
"considering converting to the `.dp` format first."
1112
)
13+
14+
tnp.experimental_enable_numpy_behavior()

deepmd/jax/jax2tf/make_model.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Callable,
4+
)
5+
6+
import tensorflow as tf
7+
import tensorflow.experimental.numpy as tnp
8+
9+
from deepmd.dpmodel.output_def import (
10+
ModelOutputDef,
11+
)
12+
from deepmd.jax.jax2tf.nlist import (
13+
build_neighbor_list,
14+
extend_coord_with_ghosts,
15+
)
16+
from deepmd.jax.jax2tf.region import (
17+
normalize_coord,
18+
)
19+
from deepmd.jax.jax2tf.transform_output import (
20+
communicate_extended_output,
21+
)
22+
23+
24+
def model_call_from_call_lower(
    *,  # enforce keyword-only arguments
    call_lower: Callable[
        [
            tnp.ndarray,
            tnp.ndarray,
            tnp.ndarray,
            tnp.ndarray,
            tnp.ndarray,
            bool,
        ],
        dict[str, tnp.ndarray],
    ],
    rcut: float,
    sel: list[int],
    mixed_types: bool,
    model_output_def: ModelOutputDef,
    coord: tnp.ndarray,
    atype: tnp.ndarray,
    box: tnp.ndarray,
    fparam: tnp.ndarray,
    aparam: tnp.ndarray,
    do_atomic_virial: bool = False,
):
    """Evaluate the model through its lower interface.

    The local system is extended with ghost atoms, a neighbor list is built
    with TensorFlow ops, ``call_lower`` is evaluated on the extended system,
    and the extended output is communicated back to the local atoms.

    Parameters
    ----------
    call_lower
        The lower interface. It is invoked with the extended coordinates,
        the extended atom types, the neighbor list and the mapping as
        positional arguments, and ``fparam`` / ``aparam`` as keywords.
    rcut
        The cut-off radius used for the ghost extension and the neighbor list.
    sel
        The maximal number of neighbors, forwarded to the neighbor list builder.
    mixed_types
        If True, the neighbor list does not distinguish atom types
        (``distinguish_types=not mixed_types``).
    model_output_def
        The output definition, forwarded to the output communication step.
    coord
        The coordinates of the atoms.
        shape: nf x (nloc x 3)
    atype
        The type of atoms. shape: nf x nloc
    box
        The simulation box. shape: nf x 9.
        A last dimension of size 0 selects the branch without a box, in
        which the coordinates are used without normalization.
    fparam
        frame parameter. nf x ndf
    aparam
        atomic parameter. nf x nloc x nda
    do_atomic_virial
        If calculate the atomic virial. It is only forwarded to the output
        communication step, not to ``call_lower``.

    Returns
    -------
    ret_dict
        The result dict of type dict[str,tnp.ndarray].
        The keys are defined by the `ModelOutputDef`.

    """
    shape_of_atype = tf.shape(atype)
    n_frames = shape_of_atype[0]
    n_local = shape_of_atype[1]

    # Only wrap the coordinates into the cell when a box is actually given.
    if tf.shape(box)[-1] == 0:
        wrapped_coord = coord
    else:
        wrapped_coord = normalize_coord(
            coord.reshape(n_frames, n_local, 3),
            box.reshape(n_frames, 3, 3),
        )

    ext_coord, ext_atype, ext_mapping = extend_coord_with_ghosts(
        wrapped_coord, atype, box, rcut
    )
    neighbor_list = build_neighbor_list(
        ext_coord,
        ext_atype,
        n_local,
        rcut,
        sel,
        distinguish_types=not mixed_types,
    )
    # the lower interface takes the extended coordinates as nf x nall x 3
    ext_coord = ext_coord.reshape(n_frames, -1, 3)
    lower_output = call_lower(
        ext_coord,
        ext_atype,
        neighbor_list,
        ext_mapping,
        fparam=fparam,
        aparam=aparam,
    )
    return communicate_extended_output(
        lower_output,
        model_output_def,
        ext_mapping,
        do_atomic_virial=do_atomic_virial,
    )

deepmd/jax/jax2tf/nlist.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Union,
4+
)
5+
6+
import tensorflow as tf
7+
import tensorflow.experimental.numpy as tnp
8+
9+
from .region import (
10+
to_face_distance,
11+
)
12+
13+
14+
## translated from torch implementation by chatgpt
15+
def build_neighbor_list(
    coord: tnp.ndarray,
    atype: tnp.ndarray,
    nloc: int,
    rcut: float,
    sel: Union[int, list[int]],
    distinguish_types: bool = True,
) -> tnp.ndarray:
    """Build the neighbor list of the local atoms, keeping nsel neighbors.

    Parameters
    ----------
    coord : tnp.ndarray
        extended coordinates of shape [batch_size, nall x 3]
    atype : tnp.ndarray
        extended atomic types of shape [batch_size, nall]
        An atom with type < 0 is treated as a virtual atom.
    nloc : int
        number of local atoms.
    rcut : float
        cut-off radius
    sel : int or list[int]
        maximal number of neighbors (of each type).
        if distinguish_types==True, sel should be a list and
        its length should be equal to the number of types.
    distinguish_types : bool
        distinguish different types.

    Returns
    -------
    neighbor_list : tnp.ndarray
        Neighbor list of shape [batch_size, nloc, nsel], the neighbors
        are stored in ascending order of distance. If the number of
        neighbors is less than nsel, the positions are masked
        with -1. The neighbor list of an atom looks like
        |------ nsel ------|
        xx xx xx xx -1 -1 -1
        if distinguish_types==True and we have two types
        |---- nsel[0] -----| |---- nsel[1] -----|
        xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
        For virtual atoms all neighboring positions are filled with -1.

    """
    nframes = tf.shape(coord)[0]
    coord = tnp.reshape(coord, (nframes, -1))
    nall = tf.shape(coord)[1] // 3
    # Virtual atoms are moved beyond every real coordinate by more than the
    # cut-off, so they can never be a neighbor of a real atom.
    if tf.size(coord) > 0:
        far_away = tnp.max(coord) + 2.0 * rcut
    else:
        far_away = tf.cast(2.0 * rcut, coord.dtype)
    # nf x nall
    virtual_mask = atype < 0
    masked_coord = tnp.where(
        virtual_mask[:, :, None],
        far_away,
        tnp.reshape(coord, (nframes, nall, 3)),
    )
    masked_coord = tnp.reshape(masked_coord, (nframes, nall * 3))
    sel = [sel] if isinstance(sel, int) else sel
    nsel = sum(sel)
    local_coord = masked_coord[:, : nloc * 3]
    all_xyz = tnp.reshape(masked_coord, [nframes, -1, 3])
    local_xyz = tnp.reshape(local_coord, [nframes, -1, 3])
    # nf x nloc x nall x 3
    diff = all_xyz[:, None, :, :] - local_xyz[:, :, None, :]
    dist = tf.linalg.norm(diff, axis=-1)
    # Lower each self-distance by one: if another atom sits at zero distance
    # from the central atom, sorting alone may fail to put the atom itself
    # first, and the first entry is the one dropped below.
    dist = dist - tf.eye(nloc, nall, dtype=diff.dtype)[tnp.newaxis, :, :]
    neighbors = tnp.argsort(dist, axis=-1)[:, :, 1:]
    dist = tnp.sort(dist, axis=-1)[:, :, 1:]
    nnei = tf.shape(dist)[2]
    if nnei >= nsel:
        dist = dist[:, :, :nsel]
        neighbors = neighbors[:, :, :nsel]
    else:
        # Pad up to nsel; the padded distances exceed rcut, so the padded
        # entries are masked with -1 below.
        dist = tnp.concatenate(
            [dist, tnp.ones([nframes, nloc, nsel - nnei], dtype=dist.dtype) + rcut],
            axis=-1,
        )
        neighbors = tnp.concatenate(
            [
                neighbors,
                tnp.ones([nframes, nloc, nsel - nnei], dtype=neighbors.dtype),
            ],
            axis=-1,
        )
    # Discard everything beyond the cut-off and all neighbors of virtual atoms.
    neighbors = tnp.where(
        tnp.logical_or((dist > rcut), virtual_mask[:, :nloc, None]),
        tnp.full_like(neighbors, -1),
        neighbors,
    )

    if distinguish_types:
        return nlist_distinguish_types(neighbors, atype, sel)
    return neighbors
112+
113+
114+
def nlist_distinguish_types(
    nlist: tnp.ndarray,
    atype: tnp.ndarray,
    sel: list[int],
):
    """Regroup a type-agnostic neighbor list into per-type sections.

    For every type index ``ii`` with capacity ``sel[ii]``, the neighbors of
    that type are moved to the front (a stable sort keeps their relative
    order), all other slots are set to -1, and the first ``sel[ii]`` columns
    are kept. The per-type sections are concatenated along the last axis.

    Parameters
    ----------
    nlist : tnp.ndarray
        neighbor list that does not distinguish atom types;
        entries equal to -1 mark empty slots.
    atype : tnp.ndarray
        atom types, indexed by the entries of ``nlist``.
    sel : list[int]
        number of neighbor slots kept for each type.

    Returns
    -------
    tnp.ndarray
        the neighbor list that distinguishes atom types.

    """
    nloc = tf.shape(nlist)[1]
    atype_per_atom = tnp.tile(atype[:, None, :], (1, nloc, 1))
    empty = nlist == -1
    # Empty slots are looked up at index 0 so the gather stays in range;
    # their type is reset to -1 right afterwards.
    safe_nlist = tnp.where(empty, tnp.zeros_like(nlist), nlist)
    neighbor_type = tnp.take_along_axis(atype_per_atom, safe_nlist, axis=2)
    neighbor_type = tnp.where(empty, tnp.full_like(neighbor_type, -1), neighbor_type)

    sections = []
    for type_index, capacity in enumerate(sel):
        is_type = tf.cast(neighbor_type == type_index, tnp.int32)
        # Sorting the negated 0/1 mask puts the matching neighbors first.
        order = tnp.argsort(-is_type, kind="stable", axis=-1)
        is_type_sorted = -tnp.sort(-is_type, axis=-1)
        picked = tnp.take_along_axis(nlist, order, axis=2)
        picked = tnp.where(
            tf.cast(is_type_sorted, tf.bool), picked, tnp.full_like(picked, -1)
        )
        sections.append(picked[..., :capacity])
    return tf.concat(sections, axis=-1)
141+
142+
143+
def tf_outer(a, b):
    """Return the outer product of two 1-D tensors, ``out[i, j] = a[i] * b[j]``."""
    outer_product = tf.einsum("i,j->ij", a, b)
    return outer_product
145+
146+
147+
## translated from torch implementation by chatgpt
148+
def extend_coord_with_ghosts(
    coord: tnp.ndarray,
    atype: tnp.ndarray,
    cell: tnp.ndarray,
    rcut: float,
):
    """Extend the coordinates of the atoms by appending periodic images.

    The number of images is large enough to ensure all the neighbors
    within rcut are appended. When the last dimension of ``cell`` has
    size 0 (no cell), nothing is appended and nall equals nloc.

    Parameters
    ----------
    coord : tnp.ndarray
        original coordinates of shape [-1, nloc*3].
    atype : tnp.ndarray
        atom type of shape [-1, nloc].
    cell : tnp.ndarray
        simulation cell tensor of shape [-1, 9].
    rcut : float
        the cutoff radius

    Returns
    -------
    extended_coord: tnp.ndarray
        extended coordinates of shape [-1, nall*3].
    extended_atype: tnp.ndarray
        extended atom type of shape [-1, nall].
    index_mapping: tnp.ndarray
        mapping extended index to the local index

    """
    shape_of_atype = tf.shape(atype)
    nf = shape_of_atype[0]
    nloc = shape_of_atype[1]
    # int64 for index
    local_index = tf.range(nloc, dtype=tnp.int64)
    local_index = tnp.tile(local_index[tnp.newaxis, :], (nf, 1))
    if tf.shape(cell)[-1] == 0:
        # no cell: the extended system is the local system itself
        nall = nloc
        ghost_coord = coord
        ghost_atype = atype
        ghost_index = local_index
    else:
        coord = tnp.reshape(coord, (nf, nloc, 3))
        cell = tnp.reshape(cell, (nf, 3, 3))
        face_distance = to_face_distance(cell)
        # number of image layers per direction, maximized over the frames
        n_layers = tf.cast(tnp.ceil(rcut / face_distance), tnp.int64)
        n_layers = tnp.max(n_layers, axis=0)
        range_x = tf.range(-n_layers[0], n_layers[0] + 1, 1, dtype=tnp.int64)
        range_y = tf.range(-n_layers[1], n_layers[1] + 1, 1, dtype=tnp.int64)
        range_z = tf.range(-n_layers[2], n_layers[2] + 1, 1, dtype=tnp.int64)
        # all integer shifts (ix, iy, iz), assembled by broadcasting
        grid = tf_outer(range_x, tnp.asarray([1, 0, 0]))[
            :, tnp.newaxis, tnp.newaxis, :
        ]
        grid = (
            grid
            + tf_outer(range_y, tnp.asarray([0, 1, 0]))[tnp.newaxis, :, tnp.newaxis, :]
        )
        grid = (
            grid
            + tf_outer(range_z, tnp.asarray([0, 0, 1]))[tnp.newaxis, tnp.newaxis, :, :]
        )
        grid = tnp.reshape(grid, (-1, 3))
        grid = tf.cast(grid, coord.dtype)
        # order the shifts by the norm of the integer shift vector
        by_norm = tnp.argsort(tf.linalg.norm(grid, axis=1))
        shift_idx = tnp.take(grid, by_norm, axis=0)
        ns = tf.shape(shift_idx)[0]
        nall = ns * nloc
        # integer shifts -> Cartesian shift vectors for every frame
        shift_vec = tnp.einsum("sd,fdk->fsk", shift_idx, cell)
        ghost_coord = coord[:, None, :, :] + shift_vec[:, :, None, :]
        ghost_atype = tnp.tile(atype[:, :, tnp.newaxis], (1, ns, 1))
        ghost_index = tnp.tile(local_index[:, :, tnp.newaxis], (1, ns, 1))

    flat_coord = tnp.reshape(ghost_coord, (nf, nall * 3))
    flat_atype = tnp.reshape(ghost_atype, (nf, nall))
    flat_index = tnp.reshape(ghost_index, (nf, nall))
    return (flat_coord, flat_atype, flat_index)

0 commit comments

Comments
 (0)