Skip to content

Commit 4d50048

Browse files
feat(jax): checkpoint I/O (#4236)
Implement a JAX checkpoint format. I name it `*.jax` as I don't find existing conventions. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced serialization and deserialization functionalities for JAX backend models. - Added support for the `.jax` file suffix in the backend configuration. - Enhanced attribute handling logic across various classes to ensure proper processing of non-null values. - **Bug Fixes** - Enhanced cleanup processes in the test suite to improve reliability. - **Chores** - Updated dependencies in the project configuration for better JAX compatibility. - Adjusted linting rules to accommodate JAX-related code. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c870ccf commit 4d50048

File tree

11 files changed

+156
-10
lines changed

11 files changed

+156
-10
lines changed

deepmd/backend/jax.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,13 @@ class JAXBackend(Backend):
3232
name = "JAX"
3333
"""The formal name of the backend."""
3434
features: ClassVar[Backend.Feature] = (
35-
Backend.Feature(0)
35+
Backend.Feature.IO
3636
# Backend.Feature.ENTRY_POINT
3737
# | Backend.Feature.DEEP_EVAL
3838
# | Backend.Feature.NEIGHBOR_STAT
39-
# | Backend.Feature.IO
4039
)
4140
"""The features of the backend."""
42-
suffixes: ClassVar[list[str]] = []
41+
suffixes: ClassVar[list[str]] = [".jax"]
4342
"""The suffixes of the backend."""
4443

4544
def is_available(self) -> bool:
@@ -94,7 +93,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
9493
Callable[[str], dict]
9594
The serialize hook of the backend.
9695
"""
97-
raise NotImplementedError
96+
from deepmd.jax.utils.serialization import (
97+
serialize_from_file,
98+
)
99+
100+
return serialize_from_file
98101

99102
@property
100103
def deserialize_hook(self) -> Callable[[str, dict], None]:
@@ -105,4 +108,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
105108
Callable[[str, dict], None]
106109
The deserialize hook of the backend.
107110
"""
108-
raise NotImplementedError
111+
from deepmd.jax.utils.serialization import (
112+
deserialize_to_file,
113+
)
114+
115+
return deserialize_to_file

deepmd/jax/atomic_model/base_atomic_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from deepmd.jax.common import (
3+
ArrayAPIVariable,
34
to_jax_array,
45
)
56
from deepmd.jax.utils.exclude_mask import (
@@ -11,6 +12,8 @@
1112
def base_atomic_model_set_attr(name, value):
1213
if name in {"out_bias", "out_std"}:
1314
value = to_jax_array(value)
15+
if value is not None:
16+
value = ArrayAPIVariable(value)
1417
elif name == "pair_excl" and value is not None:
1518
value = PairExcludeMask(value.ntypes, value.exclude_types)
1619
elif name == "atom_excl" and value is not None:

deepmd/jax/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,17 @@ def __setattr__(self, name: str, value: Any) -> None:
8181
return super().__setattr__(name, value)
8282

8383
return FlaxModule
84+
85+
86+
class ArrayAPIVariable(nnx.Variable):
87+
def __array__(self, *args, **kwargs):
88+
return self.value.__array__(*args, **kwargs)
89+
90+
def __array_namespace__(self, *args, **kwargs):
91+
return self.value.__array_namespace__(*args, **kwargs)
92+
93+
def __dlpack__(self, *args, **kwargs):
94+
return self.value.__dlpack__(*args, **kwargs)
95+
96+
def __dlpack_device__(self, *args, **kwargs):
97+
return self.value.__dlpack_device__(*args, **kwargs)

deepmd/jax/descriptor/dpa1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP,
1414
)
1515
from deepmd.jax.common import (
16+
ArrayAPIVariable,
1617
flax_module,
1718
to_jax_array,
1819
)
@@ -65,6 +66,8 @@ class DescrptBlockSeAtten(DescrptBlockSeAttenDP):
6566
def __setattr__(self, name: str, value: Any) -> None:
6667
if name in {"mean", "stddev"}:
6768
value = to_jax_array(value)
69+
if value is not None:
70+
value = ArrayAPIVariable(value)
6871
elif name in {"embeddings", "embeddings_strip"}:
6972
if value is not None:
7073
value = NetworkCollection.deserialize(value.serialize())

deepmd/jax/descriptor/se_e2_a.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
77
from deepmd.jax.common import (
8+
ArrayAPIVariable,
89
flax_module,
910
to_jax_array,
1011
)
@@ -26,6 +27,8 @@ class DescrptSeA(DescrptSeADP):
2627
def __setattr__(self, name: str, value: Any) -> None:
2728
if name in {"dstd", "davg"}:
2829
value = to_jax_array(value)
30+
if value is not None:
31+
value = ArrayAPIVariable(value)
2932
elif name in {"embeddings"}:
3033
if value is not None:
3134
value = NetworkCollection.deserialize(value.serialize())

deepmd/jax/fitting/fitting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
77
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
88
from deepmd.jax.common import (
9+
ArrayAPIVariable,
910
flax_module,
1011
to_jax_array,
1112
)
@@ -29,6 +30,8 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
2930
"aparam_inv_std",
3031
}:
3132
value = to_jax_array(value)
33+
if value is not None:
34+
value = ArrayAPIVariable(value)
3235
elif name == "emask":
3336
value = AtomExcludeMask(value.ntypes, value.exclude_types)
3437
elif name == "nets":

deepmd/jax/utils/exclude_mask.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
77
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
88
from deepmd.jax.common import (
9+
ArrayAPIVariable,
910
flax_module,
1011
to_jax_array,
1112
)
@@ -16,6 +17,8 @@ class AtomExcludeMask(AtomExcludeMaskDP):
1617
def __setattr__(self, name: str, value: Any) -> None:
1718
if name in {"type_mask"}:
1819
value = to_jax_array(value)
20+
if value is not None:
21+
value = ArrayAPIVariable(value)
1922
return super().__setattr__(name, value)
2023

2124

@@ -24,4 +27,6 @@ class PairExcludeMask(PairExcludeMaskDP):
2427
def __setattr__(self, name: str, value: Any) -> None:
2528
if name in {"type_mask"}:
2629
value = to_jax_array(value)
30+
if value is not None:
31+
value = ArrayAPIVariable(value)
2732
return super().__setattr__(name, value)

deepmd/jax/utils/serialization.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from pathlib import (
3+
Path,
4+
)
5+
6+
import orbax.checkpoint as ocp
7+
8+
from deepmd.jax.env import (
9+
jax,
10+
nnx,
11+
)
12+
from deepmd.jax.model.model import (
13+
BaseModel,
14+
get_model,
15+
)
16+
17+
18+
def deserialize_to_file(model_file: str, data: dict) -> None:
19+
"""Deserialize the dictionary to a model file.
20+
21+
Parameters
22+
----------
23+
model_file : str
24+
The model file to be saved.
25+
data : dict
26+
The dictionary to be deserialized.
27+
"""
28+
if model_file.endswith(".jax"):
29+
model = BaseModel.deserialize(data["model"])
30+
model_def_script = data["model_def_script"]
31+
_, state = nnx.split(model)
32+
with ocp.Checkpointer(
33+
ocp.CompositeCheckpointHandler("state", "model_def_script")
34+
) as checkpointer:
35+
checkpointer.save(
36+
Path(model_file).absolute(),
37+
ocp.args.Composite(
38+
state=ocp.args.StandardSave(state.to_pure_dict()),
39+
model_def_script=ocp.args.JsonSave(model_def_script),
40+
),
41+
)
42+
else:
43+
raise ValueError("JAX backend only supports converting .jax directory")
44+
45+
46+
def serialize_from_file(model_file: str) -> dict:
47+
"""Serialize the model file to a dictionary.
48+
49+
Parameters
50+
----------
51+
model_file : str
52+
The model file to be serialized.
53+
54+
Returns
55+
-------
56+
dict
57+
The serialized model data.
58+
"""
59+
if model_file.endswith(".jax"):
60+
with ocp.Checkpointer(
61+
ocp.CompositeCheckpointHandler("state", "model_def_script")
62+
) as checkpointer:
63+
data = checkpointer.restore(
64+
Path(model_file).absolute(),
65+
ocp.args.Composite(
66+
state=ocp.args.StandardRestore(),
67+
model_def_script=ocp.args.JsonRestore(),
68+
),
69+
)
70+
state = data.state
71+
72+
# convert str "1" to int 1 key
73+
def convert_str_to_int_key(item: dict):
74+
for key, value in item.copy().items():
75+
if isinstance(value, dict):
76+
convert_str_to_int_key(value)
77+
if key.isdigit():
78+
item[int(key)] = item.pop(key)
79+
80+
convert_str_to_int_key(state)
81+
82+
model_def_script = data.model_def_script
83+
abstract_model = get_model(model_def_script)
84+
graphdef, abstract_state = nnx.split(abstract_model)
85+
abstract_state.replace_by_pure_dict(state)
86+
model = nnx.merge(graphdef, abstract_state)
87+
model_dict = model.serialize()
88+
data = {
89+
"backend": "JAX",
90+
"jax_version": jax.__version__,
91+
"model": model_dict,
92+
"model_def_script": model_def_script,
93+
"@variables": {},
94+
}
95+
return data
96+
else:
97+
raise ValueError("JAX backend only supports converting .jax directory")

deepmd/jax/utils/type_embed.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
77
from deepmd.jax.common import (
8+
ArrayAPIVariable,
89
flax_module,
910
to_jax_array,
1011
)
@@ -18,6 +19,8 @@ class TypeEmbedNet(TypeEmbedNetDP):
1819
def __setattr__(self, name: str, value: Any) -> None:
1920
if name in {"econf_tebd"}:
2021
value = to_jax_array(value)
22+
if value is not None:
23+
value = ArrayAPIVariable(value)
2124
if name in {"embedding_net"}:
2225
value = EmbeddingNet.deserialize(value.serialize())
2326
return super().__setattr__(name, value)

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ cu12 = [
137137
]
138138
jax = [
139139
'jax>=0.4.33;python_version>="3.10"',
140-
'flax>=0.8.0;python_version>="3.10"',
140+
'flax>=0.10.0;python_version>="3.10"',
141+
'orbax-checkpoint;python_version>="3.10"',
142+
# The pinning of ml_dtypes may conflict with TF
143+
# 'jax-ai-stack;python_version>="3.10"',
141144
]
142145

143146
[tool.deepmd_build_backend.scripts]
@@ -402,6 +405,7 @@ banned-module-level-imports = [
402405
# Also ignore `E402` in all `__init__.py` files.
403406
"deepmd/tf/**" = ["TID253"]
404407
"deepmd/pt/**" = ["TID253"]
408+
"deepmd/jax/**" = ["TID253"]
405409
"source/tests/tf/**" = ["TID253"]
406410
"source/tests/pt/**" = ["TID253"]
407411
"source/tests/universal/pt/**" = ["TID253"]

0 commit comments

Comments
 (0)