@@ -64,9 +64,9 @@ class BaseAtomicModel(paddle.nn.Layer, BaseAtomicModel_):
         of the atomic model. Implemented by removing the pairs from the nlist.
     rcond : float, optional
         The condition number for the regression of atomic energy.
-    preset_out_bias : Dict[str, list[Optional[paddle.Tensor]]], optional
+    preset_out_bias : dict[str, list[Optional[np.ndarray]]], optional
         Specifying atomic energy contribution in vacuum. Given by key:value pairs.
-        The value is a list specifying the bias. the elements can be None or np.array of output shape.
+        The value is a list specifying the bias. the elements can be None or np.ndarray of output shape.
         For example: [None, [2.]] means type 0 is not set, type 1 is set to [2.]
         The `set_davg_zero` key in the descriptor should be set.

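As a hedged illustration of the `preset_out_bias` format described above (the "energy" key and the two-type setup are assumptions, not taken from this diff):

import numpy as np

# Hypothetical two-type system: leave type 0 unset, pin type 1's bias to 2.0.
# Each list element is either None or an np.ndarray matching the output shape.
preset_out_bias = {"energy": [None, np.array([2.0])]}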
@@ -114,15 +114,15 @@ def init_out_stat(self) -> None:
     def set_out_bias(self, out_bias: paddle.Tensor) -> None:
         self.out_bias = out_bias

-    def __setitem__(self, key, value) -> None:
+    def __setitem__(self, key: str, value: paddle.Tensor) -> None:
         if key in ["out_bias"]:
             self.out_bias = value
         elif key in ["out_std"]:
             self.out_std = value
         else:
             raise KeyError(key)

-    def __getitem__(self, key):
+    def __getitem__(self, key: str) -> paddle.Tensor:
         if key in ["out_bias"]:
             return self.out_bias
         elif key in ["out_std"]:
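For reference, the typed mapping protocol above would be exercised along these lines (a sketch; `model` stands for any concrete BaseAtomicModel subclass and the shape is a placeholder):

import paddle

model["out_bias"] = paddle.zeros([1, 2, 1])  # dispatches to __setitem__
bias = model["out_bias"]                     # dispatches to __getitem__
# model["unknown"] would raise KeyError("unknown")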
@@ -146,6 +146,10 @@ def get_intensive(self) -> bool:
         """Whether the fitting property is intensive."""
         return False

+    def has_default_fparam(self) -> bool:
+        """Check if the model has default frame parameters."""
+        return False
+
     def reinit_atom_exclude(
         self,
         exclude_types: Optional[list[int]] = None,
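The new `has_default_fparam` hook simply reports False here; a subclass that carries a default frame parameter could presumably override it as sketched below (the attribute name is hypothetical):

class MyAtomicModel(BaseAtomicModel):  # hypothetical subclass
    def has_default_fparam(self) -> bool:
        # Report True only when a default frame parameter was configured.
        return getattr(self, "default_fparam", None) is not None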
@@ -271,7 +275,6 @@ def forward_common_atomic(
             comm_dict=comm_dict,
         )
         ret_dict = self.apply_out_stat(ret_dict, atype)
-
         # nf x nloc
         atom_mask = ext_atom_mask[:, :nloc].astype(paddle.int32)
         if self.atom_excl is not None:
@@ -284,10 +287,10 @@ def forward_common_atomic(
                 out_shape2 *= ss
             ret_dict[kk] = (
                 ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
-                * atom_mask.unsqueeze(2).astype(ret_dict[kk].dtype)
+                * atom_mask[:, :, None].astype(ret_dict[kk].dtype)
             ).reshape(out_shape)
         ret_dict["mask"] = atom_mask
-
+        # raise
         return ret_dict

     def forward(
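The masking change replaces `unsqueeze(2)` with `None` indexing; both append a trailing axis so the nf x nloc mask broadcasts over the flattened per-atom output. A quick equivalence check, assuming standard Paddle indexing semantics:

import paddle

atom_mask = paddle.ones([2, 5], dtype=paddle.int32)  # nf x nloc
a = atom_mask.unsqueeze(2)       # shape [2, 5, 1]
b = atom_mask[:, :, None]        # shape [2, 5, 1]
assert a.shape == b.shape == [2, 5, 1]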
@@ -311,7 +314,9 @@ def forward(
         )

     def change_type_map(
-        self, type_map: list[str], model_with_new_type_stat=None
+        self,
+        type_map: list[str],
+        model_with_new_type_stat: Optional["BaseAtomicModel"] = None,
     ) -> None:
         """Change the type related params to new ones, according to `type_map` and the original one in the model.
         If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
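A hedged usage sketch for the expanded `change_type_map` signature (model and type names are placeholders):

# `finetune_model` and `pretrained_model` are hypothetical BaseAtomicModel instances.
finetune_model.change_type_map(
    type_map=["O", "H", "Na"],                  # new types may appear here
    model_with_new_type_stat=pretrained_model,  # statistics source for the new types
)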
@@ -378,21 +383,25 @@ def compute_or_load_stat(
         self,
         merged: Union[Callable[[], list[dict]], list[dict]],
         stat_file_path: Optional[DPPath] = None,
+        compute_or_load_out_stat: bool = True,
     ) -> NoReturn:
         """
-        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
+        Compute or load the statistics parameters of the model,
+        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
+        When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
+        and saved in the `stat_file_path`(s).
+        When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
+        and load the calculated statistics parameters.

         Parameters
         ----------
-        merged : Union[Callable[[], list[dict]], list[dict]]
-            - list[dict]: A list of data samples from various data systems.
-                Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
-                originating from the `i`-th data system.
-            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
-                only when needed. Since the sampling process can be slow and memory-intensive,
-                the lazy function helps by only sampling once.
-        stat_file_path : Optional[DPPath]
-            The path to the stat file.
+        merged
+            The lazy sampled function to get data frames from different data systems.
+        stat_file_path
+            The dictionary of paths to the statistics files.
+        compute_or_load_out_stat : bool
+            Whether to compute the output statistics.
+            If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).

         """
         raise NotImplementedError
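A sketch of calling the updated `compute_or_load_stat` on a concrete subclass, with the lazy form of `merged`; `sample_data_systems` is a hypothetical helper that performs the expensive sampling:

def lazy_merged() -> list[dict]:
    # Deferred so the slow, memory-hungry sampling runs only if statistics
    # actually need to be (re)computed.
    return sample_data_systems()

model.compute_or_load_stat(
    merged=lazy_merged,
    stat_file_path=stat_file_path,      # DPPath to cached statistics, if any
    compute_or_load_out_stat=True,      # also handle the output bias
)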
@@ -428,7 +437,7 @@ def apply_out_stat(
         self,
         ret: dict[str, paddle.Tensor],
         atype: paddle.Tensor,
-    ):
+    ) -> dict[str, paddle.Tensor]:
         """Apply the stat to each atomic output.
         The developer may override the method to define how the bias is applied
         to the atomic output of the model.
@@ -449,9 +458,9 @@ def apply_out_stat(

     def change_out_bias(
         self,
-        sample_merged,
+        sample_merged: Union[Callable[[], list[dict]], list[dict]],
         stat_file_path: Optional[DPPath] = None,
-        bias_adjust_mode="change-by-statistic",
+        bias_adjust_mode: str = "change-by-statistic",
     ) -> None:
         """Change the output bias according to the input data and the pretrained model.
@@ -501,7 +510,13 @@ def change_out_bias(
     def _get_forward_wrapper_func(self) -> Callable[..., paddle.Tensor]:
         """Get a forward wrapper of the atomic model for output bias calculation."""

-        def model_forward(coord, atype, box, fparam=None, aparam=None):
+        def model_forward(
+            coord: paddle.Tensor,
+            atype: paddle.Tensor,
+            box: Optional[paddle.Tensor],
+            fparam: Optional[paddle.Tensor] = None,
+            aparam: Optional[paddle.Tensor] = None,
+        ) -> dict[str, paddle.Tensor]:
             with (
                 paddle.no_grad()
             ):  # it's essential for pure paddle forward function to use auto_batchsize
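Once constructed, the wrapper behaves as a plain function from batched inputs to the raw output dict, with gradients disabled; a hedged call sketch (the "energy" key is an assumption):

model_forward = model._get_forward_wrapper_func()
out = model_forward(coord, atype, box)  # fparam/aparam default to None
energy = out["energy"]                  # assuming an energy head is present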
@@ -530,7 +545,7 @@ def model_forward(coord, atype, box, fparam=None, aparam=None):

         return model_forward

-    def _default_bias(self):
+    def _default_bias(self) -> paddle.Tensor:
         ntypes = self.get_ntypes()
         return paddle.zeros([self.n_out, ntypes, self.max_out_size], dtype=dtype).to(
             device=device