We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b5356d9 commit 205e506Copy full SHA for 205e506
deepmd/pt/model/model/make_model.py
@@ -297,11 +297,11 @@ def forward_common_lower(
297
comm_dict=comm_dict,
298
)
299
# add weight to atomic_output
300
- kw = next(iter(self.atomic_output_def().var_defs.keys()))
+ kw = self.atomic_model.fitting_net.var_name
301
if atomic_weight is not None:
302
# atomic_weight: nf x nloc x dim
303
- atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape(
304
- *atomic_ret[kw].shape[:-1], -1
+ atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view(
+ [atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1]
305
306
model_predict = fit_output_to_model_output(
307
atomic_ret,
0 commit comments