@@ -176,10 +176,11 @@ def init_subclass_params(sub_data, sub_class):
176176 self .use_loc_mapping = use_loc_mapping
177177 self .use_tebd_bias = use_tebd_bias
178178 self .type_map = type_map
179- self .register_buffer (
180- "buffer_type_map" ,
181- paddle .to_tensor ([ord (c ) for c in " " .join (self .type_map )]),
182- )
179+ if type_map is not None :
180+ self .register_buffer (
181+ "buffer_type_map" ,
182+ paddle .to_tensor ([ord (c ) for c in " " .join (self .type_map )]),
183+ )
183184 self .tebd_dim = self .repflow_args .n_dim
184185 self .type_embedding = TypeEmbedNet (
185186 ntypes ,
@@ -222,24 +223,30 @@ def init_subclass_params(sub_data, sub_class):
222223
223224 def get_rcut (self ) -> float :
224225 """Returns the cut-off radius."""
225- if paddle .in_dynamic_mode ():
226- return self .rcut
227- return self .repflows .get_rcut ()
226+ return self .rcut
228227
229228 def get_rcut_smth (self ) -> float :
230229 """Returns the radius where the neighbor information starts to smoothly decay to 0."""
231- if paddle .in_dynamic_mode ():
232- return self .rcut_smth
233- return self .repflows .get_rcut_smth ()
230+ return self .rcut_smth
231+
232+ def get_buffer_rcut (self ) -> paddle .Tensor :
233+ """Returns the cut-off radius."""
234+ return self .repflows .get_buffer_rcut ()
235+
236+ def get_buffer_rcut_smth (self ) -> paddle .Tensor :
237+ """Returns the radius where the neighbor information starts to smoothly decay to 0."""
238+ return self .repflows .get_buffer_rcut_smth ()
234239
235240 def get_nsel (self ) -> int :
236241 """Returns the number of selected atoms in the cut-off radius."""
237242 return sum (self .sel )
238243
239244 def get_sel (self ) -> list [int ]:
240245 """Returns the number of selected atoms for each type."""
241- if paddle .in_dynamic_mode ():
242- return self .sel
246+ return self .sel
247+
248+ def get_buffer_sel (self ) -> paddle .Tensor :
249+ """Returns the number of selected atoms for each type."""
243250 return self .repflows .get_sel ()
244251
245252 def get_ntypes (self ) -> int :
@@ -248,8 +255,10 @@ def get_ntypes(self) -> int:
248255
249256 def get_type_map (self ) -> list [str ]:
250257 """Get the name to each type of atoms."""
251- if paddle .in_dynamic_mode ():
252- return self .type_map
258+ return self .type_map
259+
260+ def get_buffer_type_map (self ) -> paddle .Tensor :
261+ """Get the name to each type of atoms."""
253262 return self .buffer_type_map
254263
255264 def get_dim_out (self ) -> int :
0 commit comments