@@ -40,7 +40,13 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
4040
4141 def exported_whether_do_atomic_virial (do_atomic_virial , has_ghost_atoms ):
4242 def call_lower_with_fixed_do_atomic_virial (
43- coord , atype , nlist , mapping , fparam , aparam
43+ coord ,
44+ atype ,
45+ nlist ,
46+ mapping ,
47+ fparam ,
48+ aparam ,
49+ atomic_weight ,
4450 ):
4551 return call_lower (
4652 coord ,
@@ -49,6 +55,7 @@ def call_lower_with_fixed_do_atomic_virial(
4955 mapping ,
5056 fparam ,
5157 aparam ,
58+ atomic_weight = atomic_weight ,
5259 do_atomic_virial = do_atomic_virial ,
5360 )
5461
@@ -68,12 +75,14 @@ def call_lower_with_fixed_do_atomic_virial(
6875 f"(nf, nloc + { nghost } )" ,
6976 f"(nf, { model .get_dim_fparam ()} )" ,
7077 f"(nf, nloc, { model .get_dim_aparam ()} )" ,
78+ "(nf, nloc, 1)" ,
7179 ],
7280 with_gradient = True ,
7381 )
7482
7583 # Save a function that can take scalar inputs.
7684 # We need to explicit set the function name, so C++ can find it.
85+ # bug: replace 1 with fitting output dim
7786 @tf .function (
7887 autograph = False ,
7988 input_signature = [
@@ -83,24 +92,32 @@ def call_lower_with_fixed_do_atomic_virial(
8392 tf .TensorSpec ([None , None ], tf .int64 ),
8493 tf .TensorSpec ([None , model .get_dim_fparam ()], tf .float64 ),
8594 tf .TensorSpec ([None , None , model .get_dim_aparam ()], tf .float64 ),
95+ tf .TensorSpec ([None , None , 1 ], tf .float64 ),
8696 ],
8797 )
8898 def call_lower_without_atomic_virial (
89- coord , atype , nlist , mapping , fparam , aparam
99+ coord ,
100+ atype ,
101+ nlist ,
102+ mapping ,
103+ fparam ,
104+ aparam ,
105+ atomic_weight ,
90106 ):
91107 nlist = format_nlist (coord , nlist , model .get_nnei (), model .get_rcut ())
92108 return tf .cond (
93109 tf .shape (coord )[1 ] == tf .shape (nlist )[1 ],
94110 lambda : exported_whether_do_atomic_virial (
95111 do_atomic_virial = False , has_ghost_atoms = False
96- )(coord , atype , nlist , mapping , fparam , aparam ),
112+ )(coord , atype , nlist , mapping , fparam , aparam , atomic_weight ),
97113 lambda : exported_whether_do_atomic_virial (
98114 do_atomic_virial = False , has_ghost_atoms = True
99- )(coord , atype , nlist , mapping , fparam , aparam ),
115+ )(coord , atype , nlist , mapping , fparam , aparam , atomic_weight ),
100116 )
101117
102118 tf_model .call_lower = call_lower_without_atomic_virial
103119
120+ # bug: replace 1 with fitting output dim
104121 @tf .function (
105122 autograph = False ,
106123 input_signature = [
@@ -110,18 +127,21 @@ def call_lower_without_atomic_virial(
110127 tf .TensorSpec ([None , None ], tf .int64 ),
111128 tf .TensorSpec ([None , model .get_dim_fparam ()], tf .float64 ),
112129 tf .TensorSpec ([None , None , model .get_dim_aparam ()], tf .float64 ),
130+ tf .TensorSpec ([None , None , 1 ], tf .float64 ),
113131 ],
114132 )
115- def call_lower_with_atomic_virial (coord , atype , nlist , mapping , fparam , aparam ):
133+ def call_lower_with_atomic_virial (
134+ coord , atype , nlist , mapping , fparam , aparam , atomic_weight
135+ ):
116136 nlist = format_nlist (coord , nlist , model .get_nnei (), model .get_rcut ())
117137 return tf .cond (
118138 tf .shape (coord )[1 ] == tf .shape (nlist )[1 ],
119139 lambda : exported_whether_do_atomic_virial (
120140 do_atomic_virial = True , has_ghost_atoms = False
121- )(coord , atype , nlist , mapping , fparam , aparam ),
141+ )(coord , atype , nlist , mapping , fparam , aparam , atomic_weight ),
122142 lambda : exported_whether_do_atomic_virial (
123143 do_atomic_virial = True , has_ghost_atoms = True
124- )(coord , atype , nlist , mapping , fparam , aparam ),
144+ )(coord , atype , nlist , mapping , fparam , aparam , atomic_weight ),
125145 )
126146
127147 tf_model .call_lower_atomic_virial = call_lower_with_atomic_virial
@@ -138,6 +158,7 @@ def call(
138158 box : Optional [tnp .ndarray ] = None ,
139159 fparam : Optional [tnp .ndarray ] = None ,
140160 aparam : Optional [tnp .ndarray ] = None ,
161+ atomic_weight : Optional [tnp .ndarray ] = None ,
141162 ):
142163 """Return model prediction.
143164
@@ -173,11 +194,13 @@ def call(
173194 box = box ,
174195 fparam = fparam ,
175196 aparam = aparam ,
197+ atomic_weight = atomic_weight ,
176198 do_atomic_virial = do_atomic_virial ,
177199 )
178200
179201 return call
180202
203+ # bug: replace 1 with fitting output dim
181204 @tf .function (
182205 autograph = True ,
183206 input_signature = [
@@ -186,6 +209,7 @@ def call(
186209 tf .TensorSpec ([None , None , None ], tf .float64 ),
187210 tf .TensorSpec ([None , model .get_dim_fparam ()], tf .float64 ),
188211 tf .TensorSpec ([None , None , model .get_dim_aparam ()], tf .float64 ),
212+ tf .TensorSpec ([None , None , 1 ], tf .float64 ),
189213 ],
190214 )
191215 def call_with_atomic_virial (
@@ -194,13 +218,15 @@ def call_with_atomic_virial(
194218 box : tnp .ndarray ,
195219 fparam : tnp .ndarray ,
196220 aparam : tnp .ndarray ,
221+ atomic_weight : tnp .ndarray ,
197222 ):
198223 return make_call_whether_do_atomic_virial (do_atomic_virial = True )(
199- coord , atype , box , fparam , aparam
224+ coord , atype , box , fparam , aparam , atomic_weight
200225 )
201226
202227 tf_model .call_atomic_virial = call_with_atomic_virial
203228
229+ # bug: replace 1 with fitting output dim
204230 @tf .function (
205231 autograph = True ,
206232 input_signature = [
@@ -209,6 +235,7 @@ def call_with_atomic_virial(
209235 tf .TensorSpec ([None , None , None ], tf .float64 ),
210236 tf .TensorSpec ([None , model .get_dim_fparam ()], tf .float64 ),
211237 tf .TensorSpec ([None , None , model .get_dim_aparam ()], tf .float64 ),
238+ tf .TensorSpec ([None , None , 1 ], tf .float64 ),
212239 ],
213240 )
214241 def call_without_atomic_virial (
@@ -217,9 +244,10 @@ def call_without_atomic_virial(
217244 box : tnp .ndarray ,
218245 fparam : tnp .ndarray ,
219246 aparam : tnp .ndarray ,
247+ atomic_weight : tnp .ndarray ,
220248 ):
221249 return make_call_whether_do_atomic_virial (do_atomic_virial = False )(
222- coord , atype , box , fparam , aparam
250+ coord , atype , box , fparam , aparam , atomic_weight
223251 )
224252
225253 tf_model .call = call_without_atomic_virial
0 commit comments