Skip to content
Merged
1 change: 1 addition & 0 deletions .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
DP_DTYPE_PROMOTION_STRICT: 1
if: matrix.group == 1
- run: mv .test_durations .test_durations_${{ matrix.group }}
- name: Upload partial durations
Expand Down
92 changes: 92 additions & 0 deletions deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
ABC,
abstractmethod,
)
from functools import (
wraps,
)
from typing import (
Any,
Callable,
Optional,
)

Expand Down Expand Up @@ -116,6 +120,94 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]:
return np.from_dlpack(x)


def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
"""A decorator that casts and casts back the input
and output tensor of a method.

The decorator should be used in a classmethod.

The decorator will do the following thing:
(1) It casts input arrays from the global precision
to precision defined by property `precision`.
(2) It casts output arrays from `precision` to
the global precision.
(3) It checks inputs and outputs and only casts when
input or output is an array and its dtype matches
the global precision and `precision`, respectively.
If it does not match (e.g. it is an integer), the decorator
will do nothing on it.

The decorator supports the array API.

Returns
-------
Callable
a decorator that casts and casts back the input and
output array of a method

Examples
--------
>>> class A:
... def __init__(self):
... self.precision = "float32"
...
... @cast_precision
... def f(x: Array, y: Array) -> Array:
... return x**2 + y
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
# only convert tensors
returned_tensor = func(
self,
*[safe_cast_array(vv, "global", self.precision) for vv in args],
**{
kk: safe_cast_array(vv, "global", self.precision)
for kk, vv in kwargs.items()
},
)
if isinstance(returned_tensor, tuple):
return tuple(
safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
)
else:
return safe_cast_array(returned_tensor, self.precision, "global")

return wrapper


def safe_cast_array(
input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray:
"""Convert an array from a precision to another precision.

If input is not an array or without the specific precision, the method will not
cast it.

Array API is supported.

Parameters
----------
input : tf.Tensor
Input tensor
from_precision : str
Array data type that is casted from
to_precision : str
Array data type that casts to

Returns
-------
tf.Tensor
casted Tensor
"""
if array_api_compat.is_array_api_obj(input):
xp = array_api_compat.array_namespace(input)
if input.dtype == get_xp_precision(xp, from_precision):
return xp.astype(input, get_xp_precision(xp, to_precision))
return input


__all__ = [
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_ENER_FLOAT_PRECISION",
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -329,6 +330,7 @@ def __init__(
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -448,6 +450,7 @@ def change_type_map(
obj["davg"] = obj["davg"][remap_index]
obj["dstd"] = obj["dstd"][remap_index]

@cast_precision
def call(
self,
coord_ext,
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -594,6 +595,7 @@ def init_subclass_params(sub_data, sub_class):
self.rcut = self.repinit.get_rcut()
self.ntypes = ntypes
self.sel = self.repinit.sel
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -757,6 +759,7 @@ def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
stddev_list.append(self.repinit_three_body.stddev)
return mean_list, stddev_list

@cast_precision
def call(
self,
coord_ext: np.ndarray,
Expand Down
14 changes: 5 additions & 9 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand All @@ -29,9 +30,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -340,6 +338,7 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -415,9 +414,7 @@ def call(
# nf x nloc x ng x ng1
grrg = np.einsum("flid,fljd->flij", gr, gr1)
# nf x nloc x (ng x ng1)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype(
GLOBAL_NP_FLOAT_PRECISION
)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
return grrg, gr[..., 1:], None, None, ww

def serialize(self) -> dict:
Expand Down Expand Up @@ -506,6 +503,7 @@ def update_sel(


class DescrptSeAArrayAPI(DescrptSeA):
@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -585,7 +583,5 @@ def call(
# grrg = xp.einsum("flid,fljd->flij", gr, gr1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
# nf x nloc x (ng x ng1)
grrg = xp.astype(
xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype
)
grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron))
return grrg, gr[..., 1:], None, None, ww
3 changes: 2 additions & 1 deletion deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
get_xp_precision,
to_numpy_array,
)
Expand Down Expand Up @@ -289,6 +290,7 @@ def cal_g(
gg = self.embeddings[(ll,)].call(ss)
return gg

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -352,7 +354,6 @@ def call(
res_rescale = 1.0 / 5.0
res = xyz_scatter * res_rescale
res = xp.reshape(res, (nf, nloc, ng))
res = xp.astype(res, get_xp_precision(xp, "global"))
return res, None, None, None, ww

def serialize(self) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
get_xp_precision,
to_numpy_array,
)
Expand Down Expand Up @@ -264,6 +265,7 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -317,7 +319,6 @@ def call(
# we don't require atype is the same in all frames
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
rr = xp.astype(rr, get_xp_precision(xp, self.precision))

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
Expand Down Expand Up @@ -349,7 +350,6 @@ def call(
result += res_ij
# nf x nloc x ng
result = xp.reshape(result, (nf, nloc, ng))
result = xp.astype(result, get_xp_precision(xp, "global"))
return result, None, None, None, ww

def serialize(self) -> dict:
Expand Down
5 changes: 3 additions & 2 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
get_xp_precision,
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -168,6 +168,7 @@ def __init__(
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -287,6 +288,7 @@ def change_type_map(
obj["davg"] = obj["davg"][remap_index]
obj["dstd"] = obj["dstd"][remap_index]

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -741,7 +743,6 @@ def call(
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
# nf x nl x ng
result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1]))
result = xp.astype(result, get_xp_precision(xp, "global"))
return (
result,
None,
Expand Down
22 changes: 15 additions & 7 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,11 @@ def _call_common(

"""
xp = array_api_compat.array_namespace(descriptor, atype)
descriptor = xp.astype(descriptor, get_xp_precision(xp, self.precision))
if fparam is not None:
fparam = xp.astype(fparam, get_xp_precision(xp, self.precision))
if aparam is not None:
aparam = xp.astype(aparam, get_xp_precision(xp, self.precision))
nf, nloc, nd = descriptor.shape
net_dim_out = self._net_out_dim()
# check input dim
Expand Down Expand Up @@ -439,18 +444,21 @@ def _call_common(
):
assert xx_zeros is not None
atom_property -= self.nets[(type_i,)](xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i, ...]
atom_property = atom_property * xp.astype(mask, atom_property.dtype)
atom_property = xp.where(
mask, atom_property, xp.zeros_like(atom_property)
)
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + xp.reshape(
xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
[nf, nloc, net_dim_out],
)
outs = self.nets[()](xx)
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
outs = xp.astype(outs, get_xp_precision(xp, "global"))
outs += xp.reshape(
xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
[nf, nloc, net_dim_out],
)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype)
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))}
3 changes: 3 additions & 0 deletions deepmd/jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True)

if os.environ.get("DP_DTYPE_PROMOTION_STRICT") == "1":
jax.config.update("jax_numpy_dtype_promotion", "strict")

__all__ = [
"jax",
"jnp",
Expand Down
Loading