Skip to content

Commit 15bb00c

Browse files
authored
fix(pt/dp): make dpa2 convertable to .dp format (#4324)
Fix #4295. BTW, I found that there seems no universal uts for `convert-backend` command. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Updated `RepformerLayer` class to version 2, enhancing serialization and deserialization processes. - Introduced a new structure for residual variables within the serialized data, improving organization and clarity. - **Bug Fixes** - Adjusted version compatibility checks in the `deserialize` method to align with the new versioning scheme. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 0199ad5 commit 15bb00c

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

deepmd/dpmodel/descriptor/repformers.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,7 +1792,7 @@ def serialize(self) -> dict:
17921792
"""
17931793
data = {
17941794
"@class": "RepformerLayer",
1795-
"@version": 1,
1795+
"@version": 2,
17961796
"rcut": self.rcut,
17971797
"rcut_smth": self.rcut_smth,
17981798
"sel": self.sel,
@@ -1877,9 +1877,11 @@ def serialize(self) -> dict:
18771877
if self.update_style == "res_residual":
18781878
data.update(
18791879
{
1880-
"g1_residual": [to_numpy_array(aa) for aa in self.g1_residual],
1881-
"g2_residual": [to_numpy_array(aa) for aa in self.g2_residual],
1882-
"h2_residual": [to_numpy_array(aa) for aa in self.h2_residual],
1880+
"@variables": {
1881+
"g1_residual": [to_numpy_array(aa) for aa in self.g1_residual],
1882+
"g2_residual": [to_numpy_array(aa) for aa in self.g2_residual],
1883+
"h2_residual": [to_numpy_array(aa) for aa in self.h2_residual],
1884+
}
18831885
}
18841886
)
18851887
return data
@@ -1894,7 +1896,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
18941896
The dict to deserialize from.
18951897
"""
18961898
data = data.copy()
1897-
check_version_compatibility(data.pop("@version"), 1, 1)
1899+
check_version_compatibility(data.pop("@version"), 2, 1)
18981900
data.pop("@class")
18991901
linear1 = data.pop("linear1")
19001902
update_chnnl_2 = data["update_chnnl_2"]
@@ -1915,9 +1917,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
19151917
attn2_ev_apply = data.pop("attn2_ev_apply", None)
19161918
loc_attn = data.pop("loc_attn", None)
19171919
g1_self_mlp = data.pop("g1_self_mlp", None)
1918-
g1_residual = data.pop("g1_residual", [])
1919-
g2_residual = data.pop("g2_residual", [])
1920-
h2_residual = data.pop("h2_residual", [])
1920+
variables = data.pop("@variables", {})
1921+
g1_residual = variables.get("g1_residual", data.pop("g1_residual", []))
1922+
g2_residual = variables.get("g2_residual", data.pop("g2_residual", []))
1923+
h2_residual = variables.get("h2_residual", data.pop("h2_residual", []))
19211924

19221925
obj = cls(**data)
19231926
obj.linear1 = NativeLayer.deserialize(linear1)

deepmd/pt/model/descriptor/repformer_layer.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,7 +1295,7 @@ def serialize(self) -> dict:
12951295
"""
12961296
data = {
12971297
"@class": "RepformerLayer",
1298-
"@version": 1,
1298+
"@version": 2,
12991299
"rcut": self.rcut,
13001300
"rcut_smth": self.rcut_smth,
13011301
"sel": self.sel,
@@ -1380,9 +1380,11 @@ def serialize(self) -> dict:
13801380
if self.update_style == "res_residual":
13811381
data.update(
13821382
{
1383-
"g1_residual": [to_numpy_array(t) for t in self.g1_residual],
1384-
"g2_residual": [to_numpy_array(t) for t in self.g2_residual],
1385-
"h2_residual": [to_numpy_array(t) for t in self.h2_residual],
1383+
"@variables": {
1384+
"g1_residual": [to_numpy_array(t) for t in self.g1_residual],
1385+
"g2_residual": [to_numpy_array(t) for t in self.g2_residual],
1386+
"h2_residual": [to_numpy_array(t) for t in self.h2_residual],
1387+
}
13861388
}
13871389
)
13881390
return data
@@ -1397,7 +1399,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
13971399
The dict to deserialize from.
13981400
"""
13991401
data = data.copy()
1400-
check_version_compatibility(data.pop("@version"), 1, 1)
1402+
check_version_compatibility(data.pop("@version"), 2, 1)
14011403
data.pop("@class")
14021404
linear1 = data.pop("linear1")
14031405
update_chnnl_2 = data["update_chnnl_2"]
@@ -1418,9 +1420,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
14181420
attn2_ev_apply = data.pop("attn2_ev_apply", None)
14191421
loc_attn = data.pop("loc_attn", None)
14201422
g1_self_mlp = data.pop("g1_self_mlp", None)
1421-
g1_residual = data.pop("g1_residual", [])
1422-
g2_residual = data.pop("g2_residual", [])
1423-
h2_residual = data.pop("h2_residual", [])
1423+
variables = data.pop("@variables", {})
1424+
g1_residual = variables.get("g1_residual", data.pop("g1_residual", []))
1425+
g2_residual = variables.get("g2_residual", data.pop("g2_residual", []))
1426+
h2_residual = variables.get("h2_residual", data.pop("h2_residual", []))
14241427

14251428
obj = cls(**data)
14261429
obj.linear1 = MLPLayer.deserialize(linear1)

0 commit comments

Comments
 (0)