Skip to content

Commit 8a5cc16

Browse files
authored
Merge branch 'devel' into devel-use_aparam_as_mask
2 parents 7c1f863 + fa61d69 commit 8a5cc16

File tree

5 files changed

+155
-15
lines changed

5 files changed

+155
-15
lines changed

deepmd/dpmodel/descriptor/se_r.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@
66
Union,
77
)
88

9+
import array_api_compat
910
import numpy as np
1011

1112
from deepmd.dpmodel import (
1213
DEFAULT_PRECISION,
1314
PRECISION_DICT,
1415
NativeOP,
1516
)
17+
from deepmd.dpmodel.common import (
18+
get_xp_precision,
19+
to_numpy_array,
20+
)
1621
from deepmd.dpmodel.utils import (
1722
EmbeddingNet,
1823
EnvMat,
@@ -25,9 +30,6 @@
2530
from deepmd.dpmodel.utils.update_sel import (
2631
UpdateSel,
2732
)
28-
from deepmd.env import (
29-
GLOBAL_NP_FLOAT_PRECISION,
30-
)
3133
from deepmd.utils.data_system import (
3234
DeepmdDataSystem,
3335
)
@@ -144,31 +146,33 @@ def __init__(
144146
self.env_protection = env_protection
145147

146148
in_dim = 1 # not considiering type embedding
147-
self.embeddings = NetworkCollection(
149+
embeddings = NetworkCollection(
148150
ntypes=self.ntypes,
149151
ndim=(1 if self.type_one_side else 2),
150152
network_type="embedding_network",
151153
)
152154
if not self.type_one_side:
153155
raise NotImplementedError("type_one_side == False not implemented")
154156
for ii in range(self.ntypes):
155-
self.embeddings[(ii,)] = EmbeddingNet(
157+
embeddings[(ii,)] = EmbeddingNet(
156158
in_dim,
157159
self.neuron,
158160
self.activation_function,
159161
self.resnet_dt,
160162
self.precision,
161163
seed=child_seed(seed, ii),
162164
)
165+
self.embeddings = embeddings
163166
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
164-
self.nnei = np.sum(self.sel)
167+
self.nnei = np.sum(self.sel).item()
165168
self.davg = np.zeros(
166169
[self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision]
167170
)
168171
self.dstd = np.ones(
169172
[self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision]
170173
)
171174
self.orig_sel = self.sel
175+
self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()]
172176

173177
def __setitem__(self, key, value):
174178
if key in ("avg", "data_avg", "davg"):
@@ -279,8 +283,9 @@ def cal_g(
279283
ss,
280284
ll,
281285
):
286+
xp = array_api_compat.array_namespace(ss)
282287
nf, nloc, nnei = ss.shape[0:3]
283-
ss = ss.reshape(nf, nloc, nnei, 1)
288+
ss = xp.reshape(ss, (nf, nloc, nnei, 1))
284289
# nf x nloc x nnei x ng
285290
gg = self.embeddings[(ll,)].call(ss)
286291
return gg
@@ -321,29 +326,34 @@ def call(
321326
sw
322327
The smooth switch function.
323328
"""
329+
xp = array_api_compat.array_namespace(coord_ext)
324330
del mapping
325331
# nf x nloc x nnei x 1
326332
rr, diff, ww = self.env_mat.call(
327333
coord_ext, atype_ext, nlist, self.davg, self.dstd, True
328334
)
329335
nf, nloc, nnei, _ = rr.shape
330-
sec = np.append([0], np.cumsum(self.sel))
336+
sec = self.sel_cumsum
331337

332338
ng = self.neuron[-1]
333-
xyz_scatter = np.zeros([nf, nloc, ng], dtype=PRECISION_DICT[self.precision])
339+
xyz_scatter = xp.zeros(
340+
[nf, nloc, ng], dtype=get_xp_precision(xp, self.precision)
341+
)
334342
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
343+
rr = xp.astype(rr, xyz_scatter.dtype)
335344
for tt in range(self.ntypes):
336345
mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]]
337346
tr = rr[:, :, sec[tt] : sec[tt + 1], :]
338-
tr = tr * mm[:, :, :, None]
347+
tr = tr * xp.astype(mm[:, :, :, None], tr.dtype)
339348
gg = self.cal_g(tr, tt)
340-
gg = np.mean(gg, axis=2)
349+
gg = xp.mean(gg, axis=2)
341350
# nf x nloc x ng x 1
342351
xyz_scatter += gg * (self.sel[tt] / self.nnei)
343352

344353
res_rescale = 1.0 / 5.0
345354
res = xyz_scatter * res_rescale
346-
res = res.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION)
355+
res = xp.reshape(res, (nf, nloc, ng))
356+
res = xp.astype(res, get_xp_precision(xp, "global"))
347357
return res, None, None, None, ww
348358

349359
def serialize(self) -> dict:
@@ -369,8 +379,8 @@ def serialize(self) -> dict:
369379
"env_mat": self.env_mat.serialize(),
370380
"embeddings": self.embeddings.serialize(),
371381
"@variables": {
372-
"davg": self.davg,
373-
"dstd": self.dstd,
382+
"davg": to_numpy_array(self.davg),
383+
"dstd": to_numpy_array(self.dstd),
374384
},
375385
"type_map": self.type_map,
376386
}

deepmd/jax/descriptor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55
from deepmd.jax.descriptor.se_e2_a import (
66
DescrptSeA,
77
)
8+
from deepmd.jax.descriptor.se_e2_r import (
9+
DescrptSeR,
10+
)
811

912
__all__ = [
1013
"DescrptSeA",
14+
"DescrptSeR",
1115
"DescrptDPA1",
1216
]

deepmd/jax/descriptor/se_e2_r.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
7+
from deepmd.jax.common import (
8+
ArrayAPIVariable,
9+
flax_module,
10+
to_jax_array,
11+
)
12+
from deepmd.jax.descriptor.base_descriptor import (
13+
BaseDescriptor,
14+
)
15+
from deepmd.jax.utils.exclude_mask import (
16+
PairExcludeMask,
17+
)
18+
from deepmd.jax.utils.network import (
19+
NetworkCollection,
20+
)
21+
22+
23+
@BaseDescriptor.register("se_e2_r")
24+
@BaseDescriptor.register("se_r")
25+
@flax_module
26+
class DescrptSeR(DescrptSeRDP):
27+
def __setattr__(self, name: str, value: Any) -> None:
28+
if name in {"dstd", "davg"}:
29+
value = to_jax_array(value)
30+
if value is not None:
31+
value = ArrayAPIVariable(value)
32+
elif name in {"embeddings"}:
33+
if value is not None:
34+
value = NetworkCollection.deserialize(value.serialize())
35+
elif name == "env_mat":
36+
# env_mat doesn't store any value
37+
pass
38+
elif name == "emask":
39+
value = PairExcludeMask(value.ntypes, value.exclude_types)
40+
41+
return super().__setattr__(name, value)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
7+
8+
from ..common import (
9+
to_array_api_strict_array,
10+
)
11+
from ..utils.exclude_mask import (
12+
PairExcludeMask,
13+
)
14+
from ..utils.network import (
15+
NetworkCollection,
16+
)
17+
18+
19+
class DescrptSeR(DescrptSeRDP):
20+
def __setattr__(self, name: str, value: Any) -> None:
21+
if name in {"dstd", "davg"}:
22+
value = to_array_api_strict_array(value)
23+
elif name in {"embeddings"}:
24+
if value is not None:
25+
value = NetworkCollection.deserialize(value.serialize())
26+
elif name == "env_mat":
27+
# env_mat doesn't store any value
28+
pass
29+
elif name == "emask":
30+
value = PairExcludeMask(value.ntypes, value.exclude_types)
31+
32+
return super().__setattr__(name, value)

source/tests/consistent/descriptor/test_se_r.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
)
1313

1414
from ..common import (
15+
INSTALLED_ARRAY_API_STRICT,
16+
INSTALLED_JAX,
1517
INSTALLED_PT,
1618
INSTALLED_TF,
1719
CommonTest,
@@ -33,14 +35,25 @@
3335
descrpt_se_r_args,
3436
)
3537

38+
if INSTALLED_JAX:
39+
from deepmd.jax.descriptor.se_e2_r import DescrptSeR as DescrptSeRJAX
40+
else:
41+
DescrptSeRJAX = None
42+
if INSTALLED_ARRAY_API_STRICT:
43+
from ...array_api_strict.descriptor.se_e2_r import (
44+
DescrptSeR as DescrptSeRArrayAPIStrict,
45+
)
46+
else:
47+
DescrptSeRArrayAPIStrict = None
48+
3649

3750
@parameterized(
3851
(True, False), # resnet_dt
3952
(True, False), # type_one_side
4053
([], [[0, 1]]), # excluded_types
4154
("float32", "float64"), # precision
4255
)
43-
class TestSeA(CommonTest, DescriptorTest, unittest.TestCase):
56+
class TestSeR(CommonTest, DescriptorTest, unittest.TestCase):
4457
@property
4558
def data(self) -> dict:
4659
(
@@ -81,9 +94,31 @@ def skip_dp(self) -> bool:
8194
) = self.param
8295
return not type_one_side or CommonTest.skip_dp
8396

97+
@property
98+
def skip_jax(self) -> bool:
99+
(
100+
resnet_dt,
101+
type_one_side,
102+
excluded_types,
103+
precision,
104+
) = self.param
105+
return not type_one_side or not INSTALLED_JAX
106+
107+
@property
108+
def skip_array_api_strict(self) -> bool:
109+
(
110+
resnet_dt,
111+
type_one_side,
112+
excluded_types,
113+
precision,
114+
) = self.param
115+
return not type_one_side or not INSTALLED_ARRAY_API_STRICT
116+
84117
tf_class = DescrptSeRTF
85118
dp_class = DescrptSeRDP
86119
pt_class = DescrptSeRPT
120+
jax_class = DescrptSeRJAX
121+
array_api_strict_class = DescrptSeRArrayAPIStrict
87122
args = descrpt_se_r_args()
88123

89124
def setUp(self):
@@ -148,6 +183,24 @@ def eval_pt(self, pt_obj: Any) -> Any:
148183
self.box,
149184
)
150185

186+
def eval_jax(self, jax_obj: Any) -> Any:
187+
return self.eval_jax_descriptor(
188+
jax_obj,
189+
self.natoms,
190+
self.coords,
191+
self.atype,
192+
self.box,
193+
)
194+
195+
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
196+
return self.eval_array_api_strict_descriptor(
197+
array_api_strict_obj,
198+
self.natoms,
199+
self.coords,
200+
self.atype,
201+
self.box,
202+
)
203+
151204
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
152205
return (ret[0],)
153206

0 commit comments

Comments
 (0)