Skip to content

Commit 32a64e1

Browse files
committed
move atomic_weight to atomic model level
1 parent 3763d8e commit 32a64e1

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def forward_common_atomic(
203203
fparam: Optional[torch.Tensor] = None,
204204
aparam: Optional[torch.Tensor] = None,
205205
comm_dict: Optional[dict[str, torch.Tensor]] = None,
206+
atomic_weight: Optional[torch.Tensor] = None,
206207
) -> dict[str, torch.Tensor]:
207208
"""Common interface for atomic inference.
208209
@@ -271,6 +272,10 @@ def forward_common_atomic(
271272
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
272273
* atom_mask[:, :, None]
273274
).view(out_shape)
275+
if atomic_weight is not None:
276+
ret_dict[kk] = ret_dict[kk] * atomic_weight.view(
277+
[out_shape[0], out_shape[1], -1]
278+
)
274279
ret_dict["mask"] = atom_mask
275280

276281
return ret_dict
@@ -284,6 +289,7 @@ def forward(
284289
fparam: Optional[torch.Tensor] = None,
285290
aparam: Optional[torch.Tensor] = None,
286291
comm_dict: Optional[dict[str, torch.Tensor]] = None,
292+
atomic_weight: Optional[torch.Tensor] = None,
287293
) -> dict[str, torch.Tensor]:
288294
return self.forward_common_atomic(
289295
extended_coord,
@@ -293,6 +299,7 @@ def forward(
293299
fparam=fparam,
294300
aparam=aparam,
295301
comm_dict=comm_dict,
302+
atomic_weight=atomic_weight,
296303
)
297304

298305
def change_type_map(

deepmd/pt/model/model/make_model.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -295,16 +295,8 @@ def forward_common_lower(
295295
fparam=fp,
296296
aparam=ap,
297297
comm_dict=comm_dict,
298+
atomic_weight=atomic_weight,
298299
)
299-
# add weight to atomic_output
300-
if hasattr(self.atomic_model, "fitting_net"):
301-
if hasattr(self.atomic_model.fitting_net, "var_name"):
302-
kw = self.atomic_model.fitting_net.var_name
303-
if atomic_weight is not None:
304-
# atomic_weight: nf x nloc x dim
305-
atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view(
306-
[atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1]
307-
)
308300
model_predict = fit_output_to_model_output(
309301
atomic_ret,
310302
self.atomic_output_def(),

0 commit comments

Comments
 (0)