Skip to content

Commit ad063bd

Browse files
committed
(np) add atomic_weight to atomic model and dipole model; add UT for pt/np
1 parent 32a64e1 commit ad063bd

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def forward_common_atomic(
149149
mapping: Optional[np.ndarray] = None,
150150
fparam: Optional[np.ndarray] = None,
151151
aparam: Optional[np.ndarray] = None,
152+
atomic_weight: Optional[np.ndarray] = None,
152153
) -> dict[str, np.ndarray]:
153154
"""Common interface for atomic inference.
154155
@@ -213,6 +214,11 @@ def forward_common_atomic(
213214
tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
214215
tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr))
215216
ret_dict[kk] = xp.reshape(tmp_arr, out_shape)
217+
if atomic_weight is not None:
218+
_out_shape = ret_dict[kk].shape
219+
ret_dict[kk] = ret_dict[kk] * atomic_weight.reshape(
220+
[_out_shape[0], _out_shape[1], -1]
221+
)
216222
ret_dict["mask"] = xp.astype(atom_mask, xp.int32)
217223

218224
return ret_dict
@@ -225,6 +231,7 @@ def call(
225231
mapping: Optional[np.ndarray] = None,
226232
fparam: Optional[np.ndarray] = None,
227233
aparam: Optional[np.ndarray] = None,
234+
atomic_weight: Optional[np.ndarray] = None,
228235
) -> dict[str, np.ndarray]:
229236
return self.forward_common_atomic(
230237
extended_coord,
@@ -233,6 +240,7 @@ def call(
233240
mapping=mapping,
234241
fparam=fparam,
235242
aparam=aparam,
243+
atomic_weight=atomic_weight,
236244
)
237245

238246
def serialize(self) -> dict:

deepmd/dpmodel/model/make_model.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def model_call_from_call_lower(
6363
fparam: Optional[np.ndarray] = None,
6464
aparam: Optional[np.ndarray] = None,
6565
do_atomic_virial: bool = False,
66+
atomic_weight: Optional[np.ndarray] = None,
6667
):
6768
"""Return model prediction from lower interface.
6869
@@ -121,6 +122,7 @@ def model_call_from_call_lower(
121122
fparam=fp,
122123
aparam=ap,
123124
do_atomic_virial=do_atomic_virial,
125+
atomic_weight=atomic_weight,
124126
)
125127
model_predict = communicate_extended_output(
126128
model_predict_lower,
@@ -224,6 +226,7 @@ def call(
224226
fparam: Optional[np.ndarray] = None,
225227
aparam: Optional[np.ndarray] = None,
226228
do_atomic_virial: bool = False,
229+
atomic_weight: Optional[np.ndarray] = None,
227230
) -> dict[str, np.ndarray]:
228231
"""Return model prediction.
229232
@@ -250,8 +253,12 @@ def call(
250253
The keys are defined by the `ModelOutputDef`.
251254
252255
"""
253-
cc, bb, fp, ap, input_prec = self.input_type_cast(
254-
coord, box=box, fparam=fparam, aparam=aparam
256+
cc, bb, fp, ap, aw, input_prec = self.input_type_cast(
257+
coord,
258+
box=box,
259+
fparam=fparam,
260+
aparam=aparam,
261+
atomic_weight=atomic_weight,
255262
)
256263
del coord, box, fparam, aparam
257264
model_predict = model_call_from_call_lower(
@@ -266,6 +273,7 @@ def call(
266273
fparam=fp,
267274
aparam=ap,
268275
do_atomic_virial=do_atomic_virial,
276+
atomic_weight=aw,
269277
)
270278
model_predict = self.output_type_cast(model_predict, input_prec)
271279
return model_predict
@@ -279,6 +287,7 @@ def call_lower(
279287
fparam: Optional[np.ndarray] = None,
280288
aparam: Optional[np.ndarray] = None,
281289
do_atomic_virial: bool = False,
290+
atomic_weight: Optional[np.ndarray] = None,
282291
):
283292
"""Return model prediction. Lower interface that takes
284293
extended atomic coordinates and types, nlist, and mapping
@@ -316,8 +325,11 @@ def call_lower(
316325
nlist,
317326
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
318327
)
319-
cc_ext, _, fp, ap, input_prec = self.input_type_cast(
320-
extended_coord, fparam=fparam, aparam=aparam
328+
cc_ext, _, fp, ap, aw, input_prec = self.input_type_cast(
329+
extended_coord,
330+
fparam=fparam,
331+
aparam=aparam,
332+
atomic_weight=atomic_weight,
321333
)
322334
del extended_coord, fparam, aparam
323335
model_predict = self.forward_common_atomic(
@@ -328,6 +340,7 @@ def call_lower(
328340
fparam=fp,
329341
aparam=ap,
330342
do_atomic_virial=do_atomic_virial,
343+
atomic_weight=aw,
331344
)
332345
model_predict = self.output_type_cast(model_predict, input_prec)
333346
return model_predict
@@ -341,6 +354,7 @@ def forward_common_atomic(
341354
fparam: Optional[np.ndarray] = None,
342355
aparam: Optional[np.ndarray] = None,
343356
do_atomic_virial: bool = False,
357+
atomic_weight: Optional[np.ndarray] = None,
344358
):
345359
atomic_ret = self.atomic_model.forward_common_atomic(
346360
extended_coord,
@@ -349,6 +363,7 @@ def forward_common_atomic(
349363
mapping=mapping,
350364
fparam=fparam,
351365
aparam=aparam,
366+
atomic_weight=atomic_weight,
352367
)
353368
return fit_output_to_model_output(
354369
atomic_ret,
@@ -365,11 +380,13 @@ def input_type_cast(
365380
box: Optional[np.ndarray] = None,
366381
fparam: Optional[np.ndarray] = None,
367382
aparam: Optional[np.ndarray] = None,
383+
atomic_weight: Optional[np.ndarray] = None,
368384
) -> tuple[
369385
np.ndarray,
370386
Optional[np.ndarray],
371387
Optional[np.ndarray],
372388
Optional[np.ndarray],
389+
Optional[np.ndarray],
373390
str,
374391
]:
375392
"""Cast the input data to global float type."""
@@ -379,18 +396,19 @@ def input_type_cast(
379396
###
380397
_lst: list[Optional[np.ndarray]] = [
381398
vv.astype(coord.dtype) if vv is not None else None
382-
for vv in [box, fparam, aparam]
399+
for vv in [box, fparam, aparam, atomic_weight]
383400
]
384-
box, fparam, aparam = _lst
401+
box, fparam, aparam, atomic_weight = _lst
385402
if input_prec == RESERVED_PRECISION_DICT[self.global_np_float_precision]:
386-
return coord, box, fparam, aparam, input_prec
403+
return coord, box, fparam, aparam, atomic_weight, input_prec
387404
else:
388405
pp = self.global_np_float_precision
389406
return (
390407
coord.astype(pp),
391408
box.astype(pp) if box is not None else None,
392409
fparam.astype(pp) if fparam is not None else None,
393410
aparam.astype(pp) if aparam is not None else None,
411+
atomic_weight.astype(pp) if atomic_weight is not None else None,
394412
input_prec,
395413
)
396414

source/tests/pt/model/test_dp_atomic_model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ def test_self_consistency(self) -> None:
7373
to_numpy_array(ret0["energy"]),
7474
to_numpy_array(ret1["energy"]),
7575
)
76+
# add test for atomic_weight
77+
aw = torch.rand([nf, nloc, 1], dtype=dtype, device=env.DEVICE)
78+
ret2 = md0.forward_common_atomic(*args, atomic_weight=aw)
79+
np.testing.assert_allclose(
80+
to_numpy_array(ret0["energy"] * aw.reshape(nf, nloc, -1)),
81+
to_numpy_array(ret2["energy"]),
82+
)
7683

7784
def test_dp_consistency(self) -> None:
7885
nf, nloc, nnei = self.nlist.shape
@@ -101,6 +108,14 @@ def test_dp_consistency(self) -> None:
101108
ret0["energy"],
102109
to_numpy_array(ret1["energy"]),
103110
)
111+
# add test for atomic_weight
112+
aw = torch.rand([nf, nloc, 1], dtype=dtype, device=env.DEVICE)
113+
ret2 = md0.forward_common_atomic(*args0, atomic_weight=to_numpy_array(aw))
114+
ret3 = md1.forward_common_atomic(*args1, atomic_weight=aw)
115+
np.testing.assert_allclose(
116+
ret2["energy"],
117+
to_numpy_array(ret3["energy"]),
118+
)
104119

105120
def test_jit(self) -> None:
106121
nf, nloc, nnei = self.nlist.shape

0 commit comments

Comments
 (0)