File tree Expand file tree Collapse file tree 7 files changed +31
-20
lines changed Expand file tree Collapse file tree 7 files changed +31
-20
lines changed Original file line number Diff line number Diff line change @@ -84,9 +84,11 @@ def __init__(
8484 paddle .nn .Layer .__init__ (self )
8585 BaseAtomicModel_ .__init__ (self )
8686 self .type_map = type_map
87- self .register_buffer (
88- "buffer_type_map" , paddle .to_tensor ([ord (c ) for c in type_map ])
89- )
87+ if type_map is not None :
88+ self .register_buffer (
89+ "buffer_type_map" ,
90+ paddle .to_tensor ([ord (c ) for c in " " .join (type_map )]),
91+ )
9092 self .ntypes = len (self .type_map )
9193 self .register_buffer (
9294 "buffer_ntypes" , paddle .to_tensor (self .ntypes , dtype = "int64" )
Original file line number Diff line number Diff line change @@ -297,9 +297,11 @@ def __init__(
297297 self .use_econf_tebd = use_econf_tebd
298298 self .use_tebd_bias = use_tebd_bias
299299 self .type_map = type_map
300- self .register_buffer (
301- "buffer_type_map" , paddle .to_tensor ([ord (c ) for c in type_map ])
302- )
300+ if type_map is not None :
301+ self .register_buffer (
302+ "buffer_type_map" ,
303+ paddle .to_tensor ([ord (c ) for c in " " .join (type_map )]),
304+ )
303305 self .compress = False
304306 self .type_embedding = TypeEmbedNet (
305307 ntypes ,
Original file line number Diff line number Diff line change @@ -265,9 +265,11 @@ def init_subclass_params(sub_data, sub_class):
265265 self .use_econf_tebd = use_econf_tebd
266266 self .use_tebd_bias = use_tebd_bias
267267 self .type_map = type_map
268- self .register_buffer (
269- "buffer_type_map" , paddle .to_tensor ([ord (c ) for c in type_map ])
270- )
268+ if type_map is not None :
269+ self .register_buffer (
270+ "buffer_type_map" ,
271+ paddle .to_tensor ([ord (c ) for c in " " .join (type_map )]),
272+ )
271273 self .type_embedding = TypeEmbedNet (
272274 ntypes ,
273275 self .repinit_args .tebd_dim ,
Original file line number Diff line number Diff line change @@ -178,7 +178,7 @@ def init_subclass_params(sub_data, sub_class):
178178 self .type_map = type_map
179179 self .register_buffer (
180180 "buffer_type_map" ,
181- paddle .to_tensor ([ord (c ) for c in self .type_map ]),
181+ paddle .to_tensor ([ord (c ) for c in " " . join ( self .type_map ) ]),
182182 )
183183 self .tebd_dim = self .repflow_args .n_dim
184184 self .type_embedding = TypeEmbedNet (
Original file line number Diff line number Diff line change @@ -95,9 +95,11 @@ def __init__(
9595 raise NotImplementedError ("old implementation of spin is not supported." )
9696 super ().__init__ ()
9797 self .type_map = type_map
98- self .register_buffer (
99- "buffer_type_map" , paddle .to_tensor ([ord (c ) for c in type_map ])
100- )
98+ if type_map is not None :
99+ self .register_buffer (
100+ "buffer_type_map" ,
101+ paddle .to_tensor ([ord (c ) for c in " " .join (type_map )]),
102+ )
101103 self .compress = False
102104 self .prec = PRECISION_DICT [precision ]
103105 self .sea = DescrptBlockSeA (
Original file line number Diff line number Diff line change @@ -165,9 +165,11 @@ def __init__(
165165 self .prec = PRECISION_DICT [precision ]
166166 self .use_econf_tebd = use_econf_tebd
167167 self .type_map = type_map
168- self .register_buffer (
169- "buffer_type_map" , paddle .to_tensor ([ord (c ) for c in type_map ])
170- )
168+ if type_map is not None :
169+ self .register_buffer (
170+ "buffer_type_map" ,
171+ paddle .to_tensor ([ord (c ) for c in " " .join (type_map )]),
172+ )
171173 self .smooth = smooth
172174 self .type_embedding = TypeEmbedNet (
173175 ntypes ,
Original file line number Diff line number Diff line change @@ -263,10 +263,11 @@ def __init__(
263263 self .rcond = rcond
264264 self .seed = seed
265265 self .type_map = type_map
266- self .register_buffer (
267- "buffer_type_map" ,
268- paddle .to_tensor ([ord (c ) for c in self .type_map ]),
269- )
266+ if type_map is not None :
267+ self .register_buffer (
268+ "buffer_type_map" ,
269+ paddle .to_tensor ([ord (c ) for c in " " .join (self .type_map )]),
270+ )
270271 self .use_aparam_as_mask = use_aparam_as_mask
271272 # order matters, should be place after the assignment of ntypes
272273 self .reinit_exclude (exclude_types )
You can’t perform that action at this time.
0 commit comments