File tree Expand file tree Collapse file tree 1 file changed +8
-6
lines changed
Expand file tree Collapse file tree 1 file changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -297,12 +297,14 @@ def forward_common_lower(
297297 comm_dict = comm_dict ,
298298 )
299299 # add weight to atomic_output
300- kw = self .atomic_model .fitting_net .var_name
301- if atomic_weight is not None :
302- # atomic_weight: nf x nloc x dim
303- atomic_ret [kw ] = atomic_ret [kw ] * atomic_weight .view (
304- [atomic_ret [kw ].shape [0 ], atomic_ret [kw ].shape [1 ], - 1 ]
305- )
300+ if hasattr (self .atomic_model , "fitting_net" ):
301+ if hasattr (self .atomic_model .fitting_net , "var_name" ):
302+ kw = self .atomic_model .fitting_net .var_name
303+ if atomic_weight is not None :
304+ # atomic_weight: nf x nloc x dim
305+ atomic_ret [kw ] = atomic_ret [kw ] * atomic_weight .view (
306+ [atomic_ret [kw ].shape [0 ], atomic_ret [kw ].shape [1 ], - 1 ]
307+ )
306308 model_predict = fit_output_to_model_output (
307309 atomic_ret ,
308310 self .atomic_output_def (),
You can’t perform that action at this time.
0 commit comments