@@ -63,6 +63,7 @@ def model_call_from_call_lower(
6363 fparam : Optional [np .ndarray ] = None ,
6464 aparam : Optional [np .ndarray ] = None ,
6565 do_atomic_virial : bool = False ,
66+ atomic_weight : Optional [np .ndarray ] = None ,
6667):
6768 """Return model prediction from lower interface.
6869
@@ -121,6 +122,7 @@ def model_call_from_call_lower(
121122 fparam = fp ,
122123 aparam = ap ,
123124 do_atomic_virial = do_atomic_virial ,
125+ atomic_weight = atomic_weight ,
124126 )
125127 model_predict = communicate_extended_output (
126128 model_predict_lower ,
@@ -224,6 +226,7 @@ def call(
224226 fparam : Optional [np .ndarray ] = None ,
225227 aparam : Optional [np .ndarray ] = None ,
226228 do_atomic_virial : bool = False ,
229+ atomic_weight : Optional [np .ndarray ] = None ,
227230 ) -> dict [str , np .ndarray ]:
228231 """Return model prediction.
229232
@@ -250,8 +253,12 @@ def call(
250253 The keys are defined by the `ModelOutputDef`.
251254
252255 """
253- cc , bb , fp , ap , input_prec = self .input_type_cast (
254- coord , box = box , fparam = fparam , aparam = aparam
256+ cc , bb , fp , ap , aw , input_prec = self .input_type_cast (
257+ coord ,
258+ box = box ,
259+ fparam = fparam ,
260+ aparam = aparam ,
261+ atomic_weight = atomic_weight ,
255262 )
256263 del coord , box , fparam , aparam
257264 model_predict = model_call_from_call_lower (
@@ -266,6 +273,7 @@ def call(
266273 fparam = fp ,
267274 aparam = ap ,
268275 do_atomic_virial = do_atomic_virial ,
276+ atomic_weight = aw ,
269277 )
270278 model_predict = self .output_type_cast (model_predict , input_prec )
271279 return model_predict
@@ -279,6 +287,7 @@ def call_lower(
279287 fparam : Optional [np .ndarray ] = None ,
280288 aparam : Optional [np .ndarray ] = None ,
281289 do_atomic_virial : bool = False ,
290+ atomic_weight : Optional [np .ndarray ] = None ,
282291 ):
283292 """Return model prediction. Lower interface that takes
284293 extended atomic coordinates and types, nlist, and mapping
@@ -316,8 +325,11 @@ def call_lower(
316325 nlist ,
317326 extra_nlist_sort = self .need_sorted_nlist_for_lower (),
318327 )
319- cc_ext , _ , fp , ap , input_prec = self .input_type_cast (
320- extended_coord , fparam = fparam , aparam = aparam
328+ cc_ext , _ , fp , ap , aw , input_prec = self .input_type_cast (
329+ extended_coord ,
330+ fparam = fparam ,
331+ aparam = aparam ,
332+ atomic_weight = atomic_weight ,
321333 )
322334 del extended_coord , fparam , aparam
323335 model_predict = self .forward_common_atomic (
@@ -328,6 +340,7 @@ def call_lower(
328340 fparam = fp ,
329341 aparam = ap ,
330342 do_atomic_virial = do_atomic_virial ,
343+ atomic_weight = aw ,
331344 )
332345 model_predict = self .output_type_cast (model_predict , input_prec )
333346 return model_predict
@@ -341,6 +354,7 @@ def forward_common_atomic(
341354 fparam : Optional [np .ndarray ] = None ,
342355 aparam : Optional [np .ndarray ] = None ,
343356 do_atomic_virial : bool = False ,
357+ atomic_weight : Optional [np .ndarray ] = None ,
344358 ):
345359 atomic_ret = self .atomic_model .forward_common_atomic (
346360 extended_coord ,
@@ -349,6 +363,7 @@ def forward_common_atomic(
349363 mapping = mapping ,
350364 fparam = fparam ,
351365 aparam = aparam ,
366+ atomic_weight = atomic_weight ,
352367 )
353368 return fit_output_to_model_output (
354369 atomic_ret ,
@@ -365,11 +380,13 @@ def input_type_cast(
365380 box : Optional [np .ndarray ] = None ,
366381 fparam : Optional [np .ndarray ] = None ,
367382 aparam : Optional [np .ndarray ] = None ,
383+ atomic_weight : Optional [np .ndarray ] = None ,
368384 ) -> tuple [
369385 np .ndarray ,
370386 Optional [np .ndarray ],
371387 Optional [np .ndarray ],
372388 Optional [np .ndarray ],
389+ Optional [np .ndarray ],
373390 str ,
374391 ]:
375392 """Cast the input data to global float type."""
@@ -379,18 +396,19 @@ def input_type_cast(
379396 ###
380397 _lst : list [Optional [np .ndarray ]] = [
381398 vv .astype (coord .dtype ) if vv is not None else None
382- for vv in [box , fparam , aparam ]
399+ for vv in [box , fparam , aparam , atomic_weight ]
383400 ]
384- box , fparam , aparam = _lst
401+ box , fparam , aparam , atomic_weight = _lst
385402 if input_prec == RESERVED_PRECISION_DICT [self .global_np_float_precision ]:
386- return coord , box , fparam , aparam , input_prec
403+ return coord , box , fparam , aparam , atomic_weight , input_prec
387404 else :
388405 pp = self .global_np_float_precision
389406 return (
390407 coord .astype (pp ),
391408 box .astype (pp ) if box is not None else None ,
392409 fparam .astype (pp ) if fparam is not None else None ,
393410 aparam .astype (pp ) if aparam is not None else None ,
411+ atomic_weight .astype (pp ) if atomic_weight is not None else None ,
394412 input_prec ,
395413 )
396414
0 commit comments