Skip to content

Commit 64fbde6

Browse files
committed
deal with case without fitting net
1 parent 205e506 commit 64fbde6

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

deepmd/pt/model/model/make_model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,12 +297,13 @@ def forward_common_lower(
297297
comm_dict=comm_dict,
298298
)
299299
# add weight to atomic_output
300-
kw = self.atomic_model.fitting_net.var_name
301-
if atomic_weight is not None:
302-
# atomic_weight: nf x nloc x dim
303-
atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view(
304-
[atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1]
305-
)
300+
if hasattr(self.atomic_model, "fitting_net"):
301+
kw = self.atomic_model.fitting_net.var_name
302+
if atomic_weight is not None:
303+
# atomic_weight: nf x nloc x dim
304+
atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view(
305+
[atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1]
306+
)
306307
model_predict = fit_output_to_model_output(
307308
atomic_ret,
308309
self.atomic_output_def(),

0 commit comments

Comments
 (0)