Skip to content

Commit 088b252

Browse files
committed
fix bug in jax2tf model convert
1 parent 778af5b commit 088b252

File tree

7 files changed

+76
-13
lines changed

7 files changed

+76
-13
lines changed

deepmd/jax/atomic_model/dp_atomic_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def forward_common_atomic(
5858
mapping: Optional[jnp.ndarray] = None,
5959
fparam: Optional[jnp.ndarray] = None,
6060
aparam: Optional[jnp.ndarray] = None,
61+
atomic_weight: Optional[jnp.ndarray] = None,
6162
) -> dict[str, jnp.ndarray]:
6263
return super().forward_common_atomic(
6364
extended_coord,
@@ -66,6 +67,7 @@ def forward_common_atomic(
6667
mapping=mapping,
6768
fparam=fparam,
6869
aparam=aparam,
70+
atomic_weight=atomic_weight,
6971
)
7072

7173
return jax_atomic_model

deepmd/jax/jax2tf/make_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def model_call_from_call_lower(
3030
tnp.ndarray,
3131
tnp.ndarray,
3232
tnp.ndarray,
33+
tnp.ndarray,
3334
bool,
3435
],
3536
dict[str, tnp.ndarray],
@@ -43,6 +44,7 @@ def model_call_from_call_lower(
4344
box: tnp.ndarray,
4445
fparam: tnp.ndarray,
4546
aparam: tnp.ndarray,
47+
atomic_weight: tnp.ndarray,
4648
do_atomic_virial: bool = False,
4749
):
4850
"""Return model prediction from lower interface.
@@ -72,8 +74,8 @@ def model_call_from_call_lower(
7274
"""
7375
atype_shape = tf.shape(atype)
7476
nframes, nloc = atype_shape[0], atype_shape[1]
75-
cc, bb, fp, ap = coord, box, fparam, aparam
76-
del coord, box, fparam, aparam
77+
cc, bb, fp, ap, aw = coord, box, fparam, aparam, atomic_weight
78+
del coord, box, fparam, aparam, atomic_weight
7779
if tf.shape(bb)[-1] != 0:
7880
coord_normalized = normalize_coord(
7981
cc.reshape(nframes, nloc, 3),
@@ -102,6 +104,7 @@ def model_call_from_call_lower(
102104
mapping,
103105
fparam=fp,
104106
aparam=ap,
107+
atomic_weight=aw,
105108
)
106109
model_predict = communicate_extended_output(
107110
model_predict_lower,

deepmd/jax/jax2tf/serialization.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
4040

4141
def exported_whether_do_atomic_virial(do_atomic_virial, has_ghost_atoms):
4242
def call_lower_with_fixed_do_atomic_virial(
43-
coord, atype, nlist, mapping, fparam, aparam
43+
coord,
44+
atype,
45+
nlist,
46+
mapping,
47+
fparam,
48+
aparam,
49+
atomic_weight,
4450
):
4551
return call_lower(
4652
coord,
@@ -49,6 +55,7 @@ def call_lower_with_fixed_do_atomic_virial(
4955
mapping,
5056
fparam,
5157
aparam,
58+
atomic_weight=atomic_weight,
5259
do_atomic_virial=do_atomic_virial,
5360
)
5461

@@ -68,12 +75,14 @@ def call_lower_with_fixed_do_atomic_virial(
6875
f"(nf, nloc + {nghost})",
6976
f"(nf, {model.get_dim_fparam()})",
7077
f"(nf, nloc, {model.get_dim_aparam()})",
78+
"(nf, nloc, 1)",
7179
],
7280
with_gradient=True,
7381
)
7482

7583
# Save a function that can take scalar inputs.
7684
# We need to explicit set the function name, so C++ can find it.
85+
# bug: replace 1 with fitting output dim
7786
@tf.function(
7887
autograph=False,
7988
input_signature=[
@@ -83,24 +92,32 @@ def call_lower_with_fixed_do_atomic_virial(
8392
tf.TensorSpec([None, None], tf.int64),
8493
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
8594
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
95+
tf.TensorSpec([None, None, 1], tf.float64),
8696
],
8797
)
8898
def call_lower_without_atomic_virial(
89-
coord, atype, nlist, mapping, fparam, aparam
99+
coord,
100+
atype,
101+
nlist,
102+
mapping,
103+
fparam,
104+
aparam,
105+
atomic_weight,
90106
):
91107
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
92108
return tf.cond(
93109
tf.shape(coord)[1] == tf.shape(nlist)[1],
94110
lambda: exported_whether_do_atomic_virial(
95111
do_atomic_virial=False, has_ghost_atoms=False
96-
)(coord, atype, nlist, mapping, fparam, aparam),
112+
)(coord, atype, nlist, mapping, fparam, aparam, atomic_weight),
97113
lambda: exported_whether_do_atomic_virial(
98114
do_atomic_virial=False, has_ghost_atoms=True
99-
)(coord, atype, nlist, mapping, fparam, aparam),
115+
)(coord, atype, nlist, mapping, fparam, aparam, atomic_weight),
100116
)
101117

102118
tf_model.call_lower = call_lower_without_atomic_virial
103119

120+
# bug: replace 1 with fitting output dim
104121
@tf.function(
105122
autograph=False,
106123
input_signature=[
@@ -110,18 +127,21 @@ def call_lower_without_atomic_virial(
110127
tf.TensorSpec([None, None], tf.int64),
111128
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
112129
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
130+
tf.TensorSpec([None, None, 1], tf.float64),
113131
],
114132
)
115-
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
133+
def call_lower_with_atomic_virial(
134+
coord, atype, nlist, mapping, fparam, aparam, atomic_weight
135+
):
116136
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
117137
return tf.cond(
118138
tf.shape(coord)[1] == tf.shape(nlist)[1],
119139
lambda: exported_whether_do_atomic_virial(
120140
do_atomic_virial=True, has_ghost_atoms=False
121-
)(coord, atype, nlist, mapping, fparam, aparam),
141+
)(coord, atype, nlist, mapping, fparam, aparam, atomic_weight),
122142
lambda: exported_whether_do_atomic_virial(
123143
do_atomic_virial=True, has_ghost_atoms=True
124-
)(coord, atype, nlist, mapping, fparam, aparam),
144+
)(coord, atype, nlist, mapping, fparam, aparam, atomic_weight),
125145
)
126146

127147
tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial
@@ -138,6 +158,7 @@ def call(
138158
box: Optional[tnp.ndarray] = None,
139159
fparam: Optional[tnp.ndarray] = None,
140160
aparam: Optional[tnp.ndarray] = None,
161+
atomic_weight: Optional[tnp.ndarray] = None,
141162
):
142163
"""Return model prediction.
143164
@@ -173,11 +194,13 @@ def call(
173194
box=box,
174195
fparam=fparam,
175196
aparam=aparam,
197+
atomic_weight=atomic_weight,
176198
do_atomic_virial=do_atomic_virial,
177199
)
178200

179201
return call
180202

203+
# bug: replace 1 with fitting output dim
181204
@tf.function(
182205
autograph=True,
183206
input_signature=[
@@ -186,6 +209,7 @@ def call(
186209
tf.TensorSpec([None, None, None], tf.float64),
187210
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
188211
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
212+
tf.TensorSpec([None, None, 1], tf.float64),
189213
],
190214
)
191215
def call_with_atomic_virial(
@@ -194,13 +218,15 @@ def call_with_atomic_virial(
194218
box: tnp.ndarray,
195219
fparam: tnp.ndarray,
196220
aparam: tnp.ndarray,
221+
atomic_weight: tnp.ndarray,
197222
):
198223
return make_call_whether_do_atomic_virial(do_atomic_virial=True)(
199-
coord, atype, box, fparam, aparam
224+
coord, atype, box, fparam, aparam, atomic_weight
200225
)
201226

202227
tf_model.call_atomic_virial = call_with_atomic_virial
203228

229+
# bug: replace 1 with fitting output dim
204230
@tf.function(
205231
autograph=True,
206232
input_signature=[
@@ -209,6 +235,7 @@ def call_with_atomic_virial(
209235
tf.TensorSpec([None, None, None], tf.float64),
210236
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
211237
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
238+
tf.TensorSpec([None, None, 1], tf.float64),
212239
],
213240
)
214241
def call_without_atomic_virial(
@@ -217,9 +244,10 @@ def call_without_atomic_virial(
217244
box: tnp.ndarray,
218245
fparam: tnp.ndarray,
219246
aparam: tnp.ndarray,
247+
atomic_weight: tnp.ndarray,
220248
):
221249
return make_call_whether_do_atomic_virial(do_atomic_virial=False)(
222-
coord, atype, box, fparam, aparam
250+
coord, atype, box, fparam, aparam, atomic_weight
223251
)
224252

225253
tf_model.call = call_without_atomic_virial

deepmd/jax/jax2tf/tfmodel.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __call__(
7979
fparam: Optional[jnp.ndarray] = None,
8080
aparam: Optional[jnp.ndarray] = None,
8181
do_atomic_virial: bool = False,
82+
atomic_weight: Optional[jnp.ndarray] = None,
8283
) -> Any:
8384
"""Return model prediction.
8485
@@ -105,7 +106,9 @@ def __call__(
105106
The keys are defined by the `ModelOutputDef`.
106107
107108
"""
108-
return self.call(coord, atype, box, fparam, aparam, do_atomic_virial)
109+
return self.call(
110+
coord, atype, box, fparam, aparam, do_atomic_virial, atomic_weight
111+
)
109112

110113
def call(
111114
self,
@@ -115,6 +118,7 @@ def call(
115118
fparam: Optional[jnp.ndarray] = None,
116119
aparam: Optional[jnp.ndarray] = None,
117120
do_atomic_virial: bool = False,
121+
atomic_weight: Optional[jnp.ndarray] = None,
118122
):
119123
"""Return model prediction.
120124
@@ -157,12 +161,17 @@ def call(
157161
(coord.shape[0], coord.shape[1], self.get_dim_aparam()),
158162
dtype=jnp.float64,
159163
)
164+
if atomic_weight is None:
165+
atomic_weight = jnp.empty(
166+
(coord.shape[0], coord.shape[1], 1), dtype=jnp.float64
167+
)
160168
return call(
161169
coord,
162170
atype,
163171
box,
164172
fparam,
165173
aparam,
174+
atomic_weight,
166175
)
167176

168177
def model_output_def(self):
@@ -179,6 +188,7 @@ def call_lower(
179188
fparam: Optional[jnp.ndarray] = None,
180189
aparam: Optional[jnp.ndarray] = None,
181190
do_atomic_virial: bool = False,
191+
atomic_weight: Optional[jnp.ndarray] = None,
182192
):
183193
if do_atomic_virial:
184194
call_lower = self._call_lower_atomic_virial
@@ -194,13 +204,18 @@ def call_lower(
194204
(extended_coord.shape[0], nlist.shape[1], self.get_dim_aparam()),
195205
dtype=jnp.float64,
196206
)
207+
if atomic_weight is None:
208+
atomic_weight = jnp.empty(
209+
(extended_coord.shape[0], nlist.shape[1], 1), dtype=jnp.float
210+
)
197211
return call_lower(
198212
extended_coord,
199213
extended_atype,
200214
nlist,
201215
mapping,
202216
fparam,
203217
aparam,
218+
atomic_weight,
204219
)
205220

206221
def get_type_map(self) -> list[str]:

deepmd/jax/model/base_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def forward_common_atomic(
2727
fparam: Optional[jnp.ndarray] = None,
2828
aparam: Optional[jnp.ndarray] = None,
2929
do_atomic_virial: bool = False,
30+
atomic_weight: Optional[jnp.ndarray] = None,
3031
):
3132
atomic_ret = self.atomic_model.forward_common_atomic(
3233
extended_coord,
@@ -35,6 +36,7 @@ def forward_common_atomic(
3536
mapping=mapping,
3637
fparam=fparam,
3738
aparam=aparam,
39+
atomic_weight=atomic_weight,
3840
)
3941
atomic_output_def = self.atomic_output_def()
4042
model_predict = {}
@@ -56,6 +58,7 @@ def eval_output(
5658
mapping,
5759
fparam,
5860
aparam,
61+
atomic_weight,
5962
*,
6063
_kk=kk,
6164
_atom_axis=atom_axis,
@@ -67,6 +70,9 @@ def eval_output(
6770
mapping=mapping[None, ...] if mapping is not None else None,
6871
fparam=fparam[None, ...] if fparam is not None else None,
6972
aparam=aparam[None, ...] if aparam is not None else None,
73+
atomic_weight=atomic_weight[None, ...]
74+
if atomic_weight is not None
75+
else None,
7076
)
7177
return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)
7278

@@ -79,6 +85,7 @@ def eval_output(
7985
mapping,
8086
fparam,
8187
aparam,
88+
atomic_weight,
8289
)
8390
# extended_force: [nf, nall, *def, 3]
8491
def_ndim = len(vdef.shape)
@@ -101,6 +108,7 @@ def eval_ce(
101108
mapping,
102109
fparam,
103110
aparam,
111+
atomic_weight,
104112
*,
105113
_kk=kk,
106114
_atom_axis=atom_axis - 1,
@@ -113,6 +121,9 @@ def eval_ce(
113121
mapping=mapping[None, ...] if mapping is not None else None,
114122
fparam=fparam[None, ...] if fparam is not None else None,
115123
aparam=aparam[None, ...] if aparam is not None else None,
124+
atomic_weight=atomic_weight[None, ...]
125+
if atomic_weight is not None
126+
else None,
116127
)
117128
nloc = nlist.shape[0]
118129
cc_loc = jax.lax.stop_gradient(cc_ext)[:nloc, ...]
@@ -130,6 +141,7 @@ def eval_ce(
130141
mapping,
131142
fparam,
132143
aparam,
144+
atomic_weight,
133145
)
134146
# move the first 3 to the last
135147
# [nf, *def, nall, 3, 3]

deepmd/jax/model/dp_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def forward_common_atomic(
5656
fparam: Optional[jnp.ndarray] = None,
5757
aparam: Optional[jnp.ndarray] = None,
5858
do_atomic_virial: bool = False,
59+
atomic_weight: Optional[jnp.ndarray] = None,
5960
):
6061
return forward_common_atomic(
6162
self,
@@ -66,6 +67,7 @@ def forward_common_atomic(
6667
fparam=fparam,
6768
aparam=aparam,
6869
do_atomic_virial=do_atomic_virial,
70+
atomic_weight=atomic_weight,
6971
)
7072

7173
def format_nlist(

deepmd/jax/utils/serialization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def exported_whether_do_atomic_virial(
5757
do_atomic_virial: bool, has_ghost_atoms: bool
5858
):
5959
def call_lower_with_fixed_do_atomic_virial(
60-
coord, atype, nlist, mapping, fparam, aparam
60+
coord, atype, nlist, mapping, fparam, aparam, atomic_weight
6161
):
6262
return call_lower(
6363
coord,
@@ -66,6 +66,7 @@ def call_lower_with_fixed_do_atomic_virial(
6666
mapping,
6767
fparam,
6868
aparam,
69+
atomic_weight,
6970
do_atomic_virial=do_atomic_virial,
7071
)
7172

0 commit comments

Comments
 (0)