Skip to content

Commit 627947a

Browse files
committed
hlo
Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent a66afd3 commit 627947a

File tree

9 files changed

+810
-9
lines changed

9 files changed

+810
-9
lines changed

deepmd/backend/jax.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class JAXBackend(Backend):
3838
# | Backend.Feature.NEIGHBOR_STAT
3939
)
4040
"""The features of the backend."""
41-
suffixes: ClassVar[list[str]] = [".jax"]
41+
suffixes: ClassVar[list[str]] = [".hlo", ".jax"]
4242
"""The suffixes of the backend."""
4343

4444
def is_available(self) -> bool:
@@ -71,7 +71,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]:
7171
type[DeepEvalBackend]
7272
The Deep Eval backend of the backend.
7373
"""
74-
raise NotImplementedError
74+
from deepmd.jax.infer.deep_eval import (
75+
DeepEval,
76+
)
77+
78+
return DeepEval
7579

7680
@property
7781
def neighbor_stat(self) -> type["NeighborStat"]:

deepmd/dpmodel/descriptor/se_e2_a.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def call(
555555
coord_ext, atype_ext, nlist, self.davg, self.dstd
556556
)
557557
nf, nloc, nnei, _ = rr.shape
558-
sec = xp.asarray(self.sel_cumsum)
558+
sec = self.sel_cumsum
559559

560560
ng = self.neuron[-1]
561561
gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)

deepmd/dpmodel/utils/serialization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def save_dp_model(filename: str, model_dict: dict) -> None:
9090
# use UTC+0 time
9191
"time": str(datetime.datetime.now(tz=datetime.timezone.utc)),
9292
}
93-
if filename_extension == ".dp":
93+
if filename_extension in (".dp", ".hlo"):
9494
variable_counter = Counter()
9595
with h5py.File(filename, "w") as f:
9696
model_dict = traverse_model_dict(
@@ -141,7 +141,7 @@ def load_dp_model(filename: str) -> dict:
141141
The loaded model dict, including meta information.
142142
"""
143143
filename_extension = Path(filename).suffix
144-
if filename_extension == ".dp":
144+
if filename_extension in {".dp", ".hlo"}:
145145
with h5py.File(filename, "r") as f:
146146
model_dict = json.loads(f.attrs["json"])
147147
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())

deepmd/jax/env.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from flax import (
99
nnx,
1010
)
11+
from jax import export as jax_export
1112

1213
jax.config.update("jax_enable_x64", True)
1314

1415
__all__ = [
1516
"jax",
1617
"jnp",
1718
"nnx",
19+
"jax_export",
1820
]

deepmd/jax/infer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

0 commit comments

Comments
 (0)