Skip to content

Commit 50a8a21

Browse files
committed
deal with case without fitting net
1 parent 205e506 commit 50a8a21

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

deepmd/pt/model/model/make_model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,12 +297,14 @@ def forward_common_lower(
297297
comm_dict=comm_dict,
298298
)
299299
# add weight to atomic_output
300-
kw = self.atomic_model.fitting_net.var_name
301-
if atomic_weight is not None:
302-
# atomic_weight: nf x nloc x dim
303-
atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view(
304-
[atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1]
305-
)
300+
if hasattr(self.atomic_model, "fitting_net"):
301+
if hasattr(self.atomic_model.fitting_net, "var_name"):
302+
kw = self.atomic_model.fitting_net.var_name
303+
if atomic_weight is not None:
304+
# atomic_weight: nf x nloc x dim
305+
atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view(
306+
[atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1]
307+
)
306308
model_predict = fit_output_to_model_output(
307309
atomic_ret,
308310
self.atomic_output_def(),

0 commit comments

Comments
 (0)