Skip to content

Commit 205e506

Browse files
committed
fix bug in model frozen
1 parent b5356d9 commit 205e506

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

deepmd/pt/model/model/make_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,11 @@ def forward_common_lower(
297297
comm_dict=comm_dict,
298298
)
299299
# add weight to atomic_output
300-
kw = next(iter(self.atomic_output_def().var_defs.keys()))
300+
kw = self.atomic_model.fitting_net.var_name
301301
if atomic_weight is not None:
302302
# atomic_weight: nf x nloc x dim
303-
atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape(
304-
*atomic_ret[kw].shape[:-1], -1
303+
atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view(
304+
[atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1]
305305
)
306306
model_predict = fit_output_to_model_output(
307307
atomic_ret,

0 commit comments

Comments
 (0)