Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def forward_common_atomic(
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
atomic_weight: Optional[np.ndarray] = None,
) -> dict[str, np.ndarray]:
"""Common interface for atomic inference.

Expand All @@ -170,6 +171,9 @@ def forward_common_atomic(
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam
atomic_weight
atomic weights for scaling outputs, shape: nf x nloc x dim_aw
if provided, all output values will be multiplied by this weight.

Returns
-------
Expand Down Expand Up @@ -213,6 +217,11 @@ def forward_common_atomic(
tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr))
ret_dict[kk] = xp.reshape(tmp_arr, out_shape)
if atomic_weight is not None:
_out_shape = ret_dict[kk].shape
ret_dict[kk] = ret_dict[kk] * atomic_weight.reshape(
[_out_shape[0], _out_shape[1], -1]
)
ret_dict["mask"] = xp.astype(atom_mask, xp.int32)

return ret_dict
Expand All @@ -225,6 +234,7 @@ def call(
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
atomic_weight: Optional[np.ndarray] = None,
) -> dict[str, np.ndarray]:
return self.forward_common_atomic(
extended_coord,
Expand All @@ -233,6 +243,7 @@ def call(
mapping=mapping,
fparam=fparam,
aparam=aparam,
atomic_weight=atomic_weight,
)

def serialize(self) -> dict:
Expand Down
32 changes: 25 additions & 7 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def model_call_from_call_lower(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
atomic_weight: Optional[np.ndarray] = None,
):
"""Return model prediction from lower interface.

Expand Down Expand Up @@ -121,6 +122,7 @@ def model_call_from_call_lower(
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
atomic_weight=atomic_weight,
)
model_predict = communicate_extended_output(
model_predict_lower,
Expand Down Expand Up @@ -224,6 +226,7 @@ def call(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
atomic_weight: Optional[np.ndarray] = None,
) -> dict[str, np.ndarray]:
"""Return model prediction.

Expand All @@ -250,8 +253,12 @@ def call(
The keys are defined by the `ModelOutputDef`.

"""
cc, bb, fp, ap, input_prec = self.input_type_cast(
coord, box=box, fparam=fparam, aparam=aparam
cc, bb, fp, ap, aw, input_prec = self.input_type_cast(
coord,
box=box,
fparam=fparam,
aparam=aparam,
atomic_weight=atomic_weight,
)
del coord, box, fparam, aparam
model_predict = model_call_from_call_lower(
Expand All @@ -266,6 +273,7 @@ def call(
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
atomic_weight=aw,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict
Expand All @@ -279,6 +287,7 @@ def call_lower(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
atomic_weight: Optional[np.ndarray] = None,
):
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
Expand Down Expand Up @@ -316,8 +325,11 @@ def call_lower(
nlist,
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
)
cc_ext, _, fp, ap, input_prec = self.input_type_cast(
extended_coord, fparam=fparam, aparam=aparam
cc_ext, _, fp, ap, aw, input_prec = self.input_type_cast(
extended_coord,
fparam=fparam,
aparam=aparam,
atomic_weight=atomic_weight,
)
del extended_coord, fparam, aparam
model_predict = self.forward_common_atomic(
Expand All @@ -328,6 +340,7 @@ def call_lower(
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
atomic_weight=aw,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict
Expand All @@ -341,6 +354,7 @@ def forward_common_atomic(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
atomic_weight: Optional[np.ndarray] = None,
):
atomic_ret = self.atomic_model.forward_common_atomic(
extended_coord,
Expand All @@ -349,6 +363,7 @@ def forward_common_atomic(
mapping=mapping,
fparam=fparam,
aparam=aparam,
atomic_weight=atomic_weight,
)
return fit_output_to_model_output(
atomic_ret,
Expand All @@ -365,11 +380,13 @@ def input_type_cast(
box: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
atomic_weight: Optional[np.ndarray] = None,
) -> tuple[
np.ndarray,
Optional[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
str,
]:
"""Cast the input data to global float type."""
Expand All @@ -379,18 +396,19 @@ def input_type_cast(
###
_lst: list[Optional[np.ndarray]] = [
vv.astype(coord.dtype) if vv is not None else None
for vv in [box, fparam, aparam]
for vv in [box, fparam, aparam, atomic_weight]
]
box, fparam, aparam = _lst
box, fparam, aparam, atomic_weight = _lst
if input_prec == RESERVED_PRECISION_DICT[self.global_np_float_precision]:
return coord, box, fparam, aparam, input_prec
return coord, box, fparam, aparam, atomic_weight, input_prec
else:
pp = self.global_np_float_precision
return (
coord.astype(pp),
box.astype(pp) if box is not None else None,
fparam.astype(pp) if fparam is not None else None,
aparam.astype(pp) if aparam is not None else None,
atomic_weight.astype(pp) if atomic_weight is not None else None,
input_prec,
)

Expand Down
2 changes: 2 additions & 0 deletions deepmd/jax/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def forward_common_atomic(
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
atomic_weight: Optional[jnp.ndarray] = None,
) -> dict[str, jnp.ndarray]:
return super().forward_common_atomic(
extended_coord,
Expand All @@ -66,6 +67,7 @@ def forward_common_atomic(
mapping=mapping,
fparam=fparam,
aparam=aparam,
atomic_weight=atomic_weight,
)

return jax_atomic_model
Expand Down
7 changes: 5 additions & 2 deletions deepmd/jax/jax2tf/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def model_call_from_call_lower(
tnp.ndarray,
tnp.ndarray,
tnp.ndarray,
tnp.ndarray,
bool,
],
dict[str, tnp.ndarray],
Expand All @@ -43,6 +44,7 @@ def model_call_from_call_lower(
box: tnp.ndarray,
fparam: tnp.ndarray,
aparam: tnp.ndarray,
atomic_weight: tnp.ndarray,
do_atomic_virial: bool = False,
):
"""Return model prediction from lower interface.
Expand Down Expand Up @@ -72,8 +74,8 @@ def model_call_from_call_lower(
"""
atype_shape = tf.shape(atype)
nframes, nloc = atype_shape[0], atype_shape[1]
cc, bb, fp, ap = coord, box, fparam, aparam
del coord, box, fparam, aparam
cc, bb, fp, ap, aw = coord, box, fparam, aparam, atomic_weight
del coord, box, fparam, aparam, atomic_weight
if tf.shape(bb)[-1] != 0:
coord_normalized = normalize_coord(
cc.reshape(nframes, nloc, 3),
Expand Down Expand Up @@ -102,6 +104,7 @@ def model_call_from_call_lower(
mapping,
fparam=fp,
aparam=ap,
atomic_weight=aw,
)
model_predict = communicate_extended_output(
model_predict_lower,
Expand Down
46 changes: 37 additions & 9 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ def deserialize_to_file(model_file: str, data: dict) -> None:

def exported_whether_do_atomic_virial(do_atomic_virial, has_ghost_atoms):
def call_lower_with_fixed_do_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
coord,
atype,
nlist,
mapping,
fparam,
aparam,
atomic_weight,
):
return call_lower(
coord,
Expand All @@ -49,6 +55,7 @@ def call_lower_with_fixed_do_atomic_virial(
mapping,
fparam,
aparam,
atomic_weight=atomic_weight,
do_atomic_virial=do_atomic_virial,
)

Expand All @@ -68,12 +75,14 @@ def call_lower_with_fixed_do_atomic_virial(
f"(nf, nloc + {nghost})",
f"(nf, {model.get_dim_fparam()})",
f"(nf, nloc, {model.get_dim_aparam()})",
"(nf, nloc, 1)",
],
with_gradient=True,
)

# Save a function that can take scalar inputs.
# We need to explicit set the function name, so C++ can find it.
# bug: replace 1 with fitting output dim
@tf.function(
autograph=False,
input_signature=[
Expand All @@ -83,24 +92,32 @@ def call_lower_with_fixed_do_atomic_virial(
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
tf.TensorSpec([None, None, 1], tf.float64),
],
)
def call_lower_without_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
coord,
atype,
nlist,
mapping,
fparam,
aparam,
atomic_weight,
):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
do_atomic_virial=False, has_ghost_atoms=False
)(coord, atype, nlist, mapping, fparam, aparam),
)(coord, atype, nlist, mapping, fparam, aparam, atomic_weight),
lambda: exported_whether_do_atomic_virial(
do_atomic_virial=False, has_ghost_atoms=True
)(coord, atype, nlist, mapping, fparam, aparam),
)(coord, atype, nlist, mapping, fparam, aparam, atomic_weight),
)

tf_model.call_lower = call_lower_without_atomic_virial

# bug: replace 1 with fitting output dim
@tf.function(
autograph=False,
input_signature=[
Expand All @@ -110,18 +127,21 @@ def call_lower_without_atomic_virial(
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
tf.TensorSpec([None, None, 1], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
def call_lower_with_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam, atomic_weight
):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
do_atomic_virial=True, has_ghost_atoms=False
)(coord, atype, nlist, mapping, fparam, aparam),
)(coord, atype, nlist, mapping, fparam, aparam, atomic_weight),
lambda: exported_whether_do_atomic_virial(
do_atomic_virial=True, has_ghost_atoms=True
)(coord, atype, nlist, mapping, fparam, aparam),
)(coord, atype, nlist, mapping, fparam, aparam, atomic_weight),
)

tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial
Expand All @@ -138,6 +158,7 @@ def call(
box: Optional[tnp.ndarray] = None,
fparam: Optional[tnp.ndarray] = None,
aparam: Optional[tnp.ndarray] = None,
atomic_weight: Optional[tnp.ndarray] = None,
):
"""Return model prediction.

Expand Down Expand Up @@ -173,11 +194,13 @@ def call(
box=box,
fparam=fparam,
aparam=aparam,
atomic_weight=atomic_weight,
do_atomic_virial=do_atomic_virial,
)

return call

# bug: replace 1 with fitting output dim
@tf.function(
autograph=True,
input_signature=[
Expand All @@ -186,6 +209,7 @@ def call(
tf.TensorSpec([None, None, None], tf.float64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
tf.TensorSpec([None, None, 1], tf.float64),
],
)
def call_with_atomic_virial(
Expand All @@ -194,13 +218,15 @@ def call_with_atomic_virial(
box: tnp.ndarray,
fparam: tnp.ndarray,
aparam: tnp.ndarray,
atomic_weight: tnp.ndarray,
):
return make_call_whether_do_atomic_virial(do_atomic_virial=True)(
coord, atype, box, fparam, aparam
coord, atype, box, fparam, aparam, atomic_weight
)

tf_model.call_atomic_virial = call_with_atomic_virial

# bug: replace 1 with fitting output dim
@tf.function(
autograph=True,
input_signature=[
Expand All @@ -209,6 +235,7 @@ def call_with_atomic_virial(
tf.TensorSpec([None, None, None], tf.float64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
tf.TensorSpec([None, None, 1], tf.float64),
],
)
def call_without_atomic_virial(
Expand All @@ -217,9 +244,10 @@ def call_without_atomic_virial(
box: tnp.ndarray,
fparam: tnp.ndarray,
aparam: tnp.ndarray,
atomic_weight: tnp.ndarray,
):
return make_call_whether_do_atomic_virial(do_atomic_virial=False)(
coord, atype, box, fparam, aparam
coord, atype, box, fparam, aparam, atomic_weight
)

tf_model.call = call_without_atomic_virial
Expand Down
Loading
Loading