Skip to content

Commit 25bb821

Browse files
njzjzcoderabbitai[bot]pre-commit-ci[bot]wanghan-iapcm
authored
feat(jax/array-api): DPA-2 (#4294)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced new classes for enhanced descriptor functionality, including `DescrptDPA2`, `DescrptBlockRepformers`, and `DescrptBlockSeTTebd`. - Added serialization and deserialization methods for better state management of descriptor objects. - **Improvements** - Enhanced compatibility with various array backends through the integration of `array_api_compat`. - Refactored existing methods to utilize new array API functions for improved performance. - Updated documentation to reflect JAX as a supported backend alongside PyTorch. - **Bug Fixes** - Updated handling of attributes in several classes to ensure correct deserialization and type safety. - **Tests** - Enhanced testing capabilities for JAX and Array API Strict backend integration, including conditional imports and new evaluation methods. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Han Wang <[email protected]>
1 parent 6bc730f commit 25bb821

File tree

9 files changed

+616
-118
lines changed

9 files changed

+616
-118
lines changed

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@
44
Union,
55
)
66

7+
import array_api_compat
78
import numpy as np
89

910
from deepmd.dpmodel import (
1011
NativeOP,
1112
)
13+
from deepmd.dpmodel.array_api import (
14+
xp_take_along_axis,
15+
)
16+
from deepmd.dpmodel.common import (
17+
to_numpy_array,
18+
)
1219
from deepmd.dpmodel.utils import (
1320
EnvMat,
1421
NetworkCollection,
@@ -787,9 +794,10 @@ def call(
787794
The smooth switch function. shape: nf x nloc x nnei
788795
789796
"""
797+
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
790798
use_three_body = self.use_three_body
791799
nframes, nloc, nnei = nlist.shape
792-
nall = coord_ext.reshape(nframes, -1).shape[1] // 3
800+
nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3
793801
# nlists
794802
nlist_dict = build_multiple_neighbor_list(
795803
coord_ext,
@@ -798,7 +806,10 @@ def call(
798806
self.nsel_list,
799807
)
800808
# repinit
801-
g1_ext = self.type_embedding.call()[atype_ext]
809+
g1_ext = xp.reshape(
810+
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
811+
(nframes, nall, self.tebd_dim),
812+
)
802813
g1_inp = g1_ext[:, :nloc, :]
803814
g1, _, _, _, _ = self.repinit(
804815
nlist_dict[
@@ -823,16 +834,18 @@ def call(
823834
g1_ext,
824835
mapping,
825836
)
826-
g1 = np.concatenate([g1, g1_three_body], axis=-1)
837+
g1 = xp.concat([g1, g1_three_body], axis=-1)
827838
# linear to change shape
828839
g1 = self.g1_shape_tranform(g1)
829840
if self.add_tebd_to_repinit_out:
830841
assert self.tebd_transform is not None
831842
g1 = g1 + self.tebd_transform(g1_inp)
832843
# mapping g1
833844
assert mapping is not None
834-
mapping_ext = np.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1]))
835-
g1_ext = np.take_along_axis(g1, mapping_ext, axis=1)
845+
mapping_ext = xp.tile(
846+
xp.reshape(mapping, (nframes, nall, 1)), (1, 1, g1.shape[-1])
847+
)
848+
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
836849
# repformer
837850
g1, g2, h2, rot_mat, sw = self.repformers(
838851
nlist_dict[
@@ -846,7 +859,7 @@ def call(
846859
mapping,
847860
)
848861
if self.concat_output_tebd:
849-
g1 = np.concatenate([g1, g1_inp], axis=-1)
862+
g1 = xp.concat([g1, g1_inp], axis=-1)
850863
return g1, rot_mat, g2, h2, sw
851864

852865
def serialize(self) -> dict:
@@ -883,8 +896,8 @@ def serialize(self) -> dict:
883896
"embeddings": repinit.embeddings.serialize(),
884897
"env_mat": EnvMat(repinit.rcut, repinit.rcut_smth).serialize(),
885898
"@variables": {
886-
"davg": repinit["davg"],
887-
"dstd": repinit["dstd"],
899+
"davg": to_numpy_array(repinit["davg"]),
900+
"dstd": to_numpy_array(repinit["dstd"]),
888901
},
889902
}
890903
if repinit.tebd_input_mode in ["strip"]:
@@ -896,8 +909,8 @@ def serialize(self) -> dict:
896909
"repformer_layers": [layer.serialize() for layer in repformers.layers],
897910
"env_mat": EnvMat(repformers.rcut, repformers.rcut_smth).serialize(),
898911
"@variables": {
899-
"davg": repformers["davg"],
900-
"dstd": repformers["dstd"],
912+
"davg": to_numpy_array(repformers["davg"]),
913+
"dstd": to_numpy_array(repformers["dstd"]),
901914
},
902915
}
903916
data.update(
@@ -913,8 +926,8 @@ def serialize(self) -> dict:
913926
repinit_three_body.rcut, repinit_three_body.rcut_smth
914927
).serialize(),
915928
"@variables": {
916-
"davg": repinit_three_body["davg"],
917-
"dstd": repinit_three_body["dstd"],
929+
"davg": to_numpy_array(repinit_three_body["davg"]),
930+
"dstd": to_numpy_array(repinit_three_body["dstd"]),
918931
},
919932
}
920933
if repinit_three_body.tebd_input_mode in ["strip"]:

0 commit comments

Comments
 (0)