Skip to content

Commit e7925f3

Browse files
authored
feat(jax): energy, dos, dipole, polar, property atomic model & model (#4384)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced several new atomic model classes: `DPAtomicModelDipole`, `DPAtomicModelDOS`, `DPAtomicModelEnergy`, `DPAtomicModelPolar`, and `DPAtomicModelProperty`. - Added new model classes: `DipoleModel`, `DOSModel`, `PolarModel`, and `PropertyModel` for enhanced functionalities. - Implemented a new function to create JAX-compatible models from existing DP models, improving integration with JAX. - **Bug Fixes** - Enhanced test suite to support JAX backend, ensuring compatibility and flexibility in testing. - **Documentation** - Updated public API to include new models and functionalities. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 4334377 commit e7925f3

19 files changed

+910
-95
lines changed

deepmd/dpmodel/atomic_model/polar_atomic_model.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3+
import array_api_compat
34
import numpy as np
45

56
from deepmd.dpmodel.fitting.polarizability_fitting import (
@@ -34,29 +35,29 @@ def apply_out_stat(
3435
The atom types. nf x nloc
3536
3637
"""
38+
xp = array_api_compat.array_namespace(atype)
3739
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
3840

39-
if self.fitting_net.shift_diag:
41+
if self.fitting.shift_diag:
4042
nframes, nloc = atype.shape
4143
dtype = out_bias[self.bias_keys[0]].dtype
4244
for kk in self.bias_keys:
4345
ntypes = out_bias[kk].shape[0]
44-
temp = np.zeros(ntypes, dtype=dtype)
45-
temp = np.mean(
46-
np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2),
46+
temp = xp.mean(
47+
xp.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2),
4748
axis=1,
4849
)
4950
modified_bias = temp[atype]
5051

5152
# (nframes, nloc, 1)
5253
modified_bias = (
53-
modified_bias[..., np.newaxis] * (self.fitting_net.scale[atype])
54+
modified_bias[..., xp.newaxis] * (self.fitting.scale[atype])
5455
)
5556

56-
eye = np.eye(3, dtype=dtype)
57-
eye = np.tile(eye, (nframes, nloc, 1, 1))
57+
eye = xp.eye(3, dtype=dtype)
58+
eye = xp.tile(eye, (nframes, nloc, 1, 1))
5859
# (nframes, nloc, 3, 3)
59-
modified_bias = modified_bias[..., np.newaxis] * eye
60+
modified_bias = modified_bias[..., xp.newaxis] * eye
6061

6162
# nf x nloc x odims, out_bias: ntypes x odims
6263
ret[kk] = ret[kk] + modified_bias
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.atomic_model.dipole_atomic_model import (
3+
DPDipoleAtomicModel as DPAtomicModelDipoleDP,
4+
)
5+
from deepmd.jax.atomic_model.dp_atomic_model import (
6+
make_jax_dp_atomic_model_from_dpmodel,
7+
)
8+
9+
10+
class DPAtomicModelDipole(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDipoleDP)):
11+
pass
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.atomic_model.dos_atomic_model import (
3+
DPDOSAtomicModel as DPAtomicModelDOSDP,
4+
)
5+
from deepmd.jax.atomic_model.dp_atomic_model import (
6+
make_jax_dp_atomic_model_from_dpmodel,
7+
)
8+
9+
10+
class DPAtomicModelDOS(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDOSDP)):
11+
pass

deepmd/jax/atomic_model/dp_atomic_model.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,53 @@
2323
)
2424

2525

26-
@flax_module
27-
class DPAtomicModel(DPAtomicModelDP):
28-
base_descriptor_cls = BaseDescriptor
29-
"""The base descriptor class."""
30-
base_fitting_cls = BaseFitting
31-
"""The base fitting class."""
32-
33-
def __setattr__(self, name: str, value: Any) -> None:
34-
value = base_atomic_model_set_attr(name, value)
35-
return super().__setattr__(name, value)
36-
37-
def forward_common_atomic(
38-
self,
39-
extended_coord: jnp.ndarray,
40-
extended_atype: jnp.ndarray,
41-
nlist: jnp.ndarray,
42-
mapping: Optional[jnp.ndarray] = None,
43-
fparam: Optional[jnp.ndarray] = None,
44-
aparam: Optional[jnp.ndarray] = None,
45-
) -> dict[str, jnp.ndarray]:
46-
return super().forward_common_atomic(
47-
extended_coord,
48-
extended_atype,
49-
jax.lax.stop_gradient(nlist),
50-
mapping=mapping,
51-
fparam=fparam,
52-
aparam=aparam,
53-
)
26+
def make_jax_dp_atomic_model_from_dpmodel(
27+
dpmodel_atomic_model: type[DPAtomicModelDP],
28+
) -> type[DPAtomicModelDP]:
29+
"""Make a JAX backend DP atomic model from a DPModel backend DP atomic model.
30+
31+
Parameters
32+
----------
33+
dpmodel_atomic_model : type[DPAtomicModelDP]
34+
The DPModel backend DP atomic model.
35+
36+
Returns
37+
-------
38+
type[DPAtomicModel]
39+
The JAX backend DP atomic model.
40+
"""
41+
42+
@flax_module
43+
class jax_atomic_model(dpmodel_atomic_model):
44+
base_descriptor_cls = BaseDescriptor
45+
"""The base descriptor class."""
46+
base_fitting_cls = BaseFitting
47+
"""The base fitting class."""
48+
49+
def __setattr__(self, name: str, value: Any) -> None:
50+
value = base_atomic_model_set_attr(name, value)
51+
return super().__setattr__(name, value)
52+
53+
def forward_common_atomic(
54+
self,
55+
extended_coord: jnp.ndarray,
56+
extended_atype: jnp.ndarray,
57+
nlist: jnp.ndarray,
58+
mapping: Optional[jnp.ndarray] = None,
59+
fparam: Optional[jnp.ndarray] = None,
60+
aparam: Optional[jnp.ndarray] = None,
61+
) -> dict[str, jnp.ndarray]:
62+
return super().forward_common_atomic(
63+
extended_coord,
64+
extended_atype,
65+
jax.lax.stop_gradient(nlist),
66+
mapping=mapping,
67+
fparam=fparam,
68+
aparam=aparam,
69+
)
70+
71+
return jax_atomic_model
72+
73+
74+
class DPAtomicModel(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDP)):
75+
pass
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.atomic_model.energy_atomic_model import (
3+
DPEnergyAtomicModel as DPAtomicModelEnergyDP,
4+
)
5+
from deepmd.jax.atomic_model.dp_atomic_model import (
6+
make_jax_dp_atomic_model_from_dpmodel,
7+
)
8+
9+
10+
class DPAtomicModelEnergy(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelEnergyDP)):
11+
pass
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.atomic_model.polar_atomic_model import (
3+
DPPolarAtomicModel as DPAtomicModelPolarDP,
4+
)
5+
from deepmd.jax.atomic_model.dp_atomic_model import (
6+
make_jax_dp_atomic_model_from_dpmodel,
7+
)
8+
9+
10+
class DPAtomicModelPolar(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPolarDP)):
11+
pass
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.atomic_model.property_atomic_model import (
3+
DPPropertyAtomicModel as DPAtomicModelPropertyDP,
4+
)
5+
from deepmd.jax.atomic_model.dp_atomic_model import (
6+
make_jax_dp_atomic_model_from_dpmodel,
7+
)
8+
9+
10+
class DPAtomicModelProperty(
11+
make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPropertyDP)
12+
):
13+
pass

deepmd/jax/model/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from .dipole_model import (
3+
DipoleModel,
4+
)
5+
from .dos_model import (
6+
DOSModel,
7+
)
28
from .dp_zbl_model import (
39
DPZBLLinearEnergyAtomicModel,
410
)
511
from .ener_model import (
612
EnergyModel,
713
)
14+
from .polar_model import (
15+
PolarModel,
16+
)
17+
from .property_model import (
18+
PropertyModel,
19+
)
820

921
__all__ = [
1022
"EnergyModel",
1123
"DPZBLLinearEnergyAtomicModel",
24+
"DOSModel",
25+
"DipoleModel",
26+
"PolarModel",
27+
"PropertyModel",
1228
]

deepmd/jax/model/dipole_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
from deepmd.dpmodel.model.dipole_model import DipoleModel as DipoleModelDP
4+
from deepmd.jax.atomic_model.dipole_atomic_model import (
5+
DPAtomicModelDipole,
6+
)
7+
from deepmd.jax.model.base_model import (
8+
BaseModel,
9+
)
10+
from deepmd.jax.model.dp_model import (
11+
make_jax_dp_model_from_dpmodel,
12+
)
13+
14+
15+
@BaseModel.register("dipole")
16+
class DipoleModel(make_jax_dp_model_from_dpmodel(DipoleModelDP, DPAtomicModelDipole)):
17+
pass

deepmd/jax/model/dos_model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.model.dos_model import DOSModel as DOSModelDP
3+
from deepmd.jax.atomic_model.dos_atomic_model import (
4+
DPAtomicModelDOS,
5+
)
6+
from deepmd.jax.model.base_model import (
7+
BaseModel,
8+
)
9+
from deepmd.jax.model.dp_model import (
10+
make_jax_dp_model_from_dpmodel,
11+
)
12+
13+
14+
@BaseModel.register("dos")
15+
class DOSModel(make_jax_dp_model_from_dpmodel(DOSModelDP, DPAtomicModelDOS)):
16+
pass

0 commit comments

Comments
 (0)