File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed
Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -297,12 +297,13 @@ def forward_common_lower(
297297 comm_dict = comm_dict ,
298298 )
299299 # add weight to atomic_output
300- kw = self .atomic_model .fitting_net .var_name
301- if atomic_weight is not None :
302- # atomic_weight: nf x nloc x dim
303- atomic_ret [kw ] = atomic_ret [kw ] * atomic_weight .view (
304- [atomic_ret [kw ].shape [0 ], atomic_ret [kw ].shape [1 ], - 1 ]
305- )
300+ if hasattr (self .atomic_model , "fitting_net" ):
301+ kw = self .atomic_model .fitting_net .var_name
302+ if atomic_weight is not None :
303+ # atomic_weight: nf x nloc x dim
304+ atomic_ret [kw ] = atomic_ret [kw ] * atomic_weight .view (
305+ [atomic_ret [kw ].shape [0 ], atomic_ret [kw ].shape [1 ], - 1 ]
306+ )
306307 model_predict = fit_output_to_model_output (
307308 atomic_ret ,
308309 self .atomic_output_def (),
You can’t perform that action at this time.
0 commit comments