Skip to content

Commit c12bc01

Browse files
authored
feat(pt): calculate stat during compression if --skip-neighbor-stat (#4330)
If `--skip-neighbor-stat` was set during training, `dp compress` now calculates the neighbor statistics first. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced `enable_compression` function to accept a `training_script` parameter for improved error handling and functionality. - Updated the `compress` command to allow specification of a training script during execution. - Introduced a new testing framework for models using the `--skip-neighbor-stat` flag, validating their functionality. - **Bug Fixes** - Improved error handling for cases where the model's minimum neighbor distance is not saved. - **Tests** - Added a new test class and methods to validate the functionality of models initialized with skip neighbor statistics. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent cb3e39e commit c12bc01

File tree

3 files changed

+212
-1
lines changed

3 files changed

+212
-1
lines changed

deepmd/pt/entrypoints/compress.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import json
3+
import logging
4+
from typing import (
5+
Optional,
6+
)
37

48
import torch
59

10+
from deepmd.common import (
11+
j_loader,
12+
)
613
from deepmd.pt.model.model import (
714
get_model,
815
)
16+
from deepmd.pt.utils import (
17+
env,
18+
)
19+
from deepmd.pt.utils.update_sel import (
20+
UpdateSel,
21+
)
22+
from deepmd.utils.compat import (
23+
update_deepmd_input,
24+
)
25+
from deepmd.utils.data_system import (
26+
get_data,
27+
)
28+
29+
log = logging.getLogger(__name__)
930

1031

1132
def enable_compression(
@@ -14,12 +35,44 @@ def enable_compression(
1435
stride: float = 0.01,
1536
extrapolate: int = 5,
1637
check_frequency: int = -1,
38+
training_script: Optional[str] = None,
1739
):
1840
saved_model = torch.jit.load(input_file, map_location="cpu")
1941
model_def_script = json.loads(saved_model.model_def_script)
2042
model = get_model(model_def_script)
2143
model.load_state_dict(saved_model.state_dict())
2244

45+
if model.get_min_nbor_dist() is None:
46+
log.info(
47+
"Minimal neighbor distance is not saved in the model, compute it from the training data."
48+
)
49+
if training_script is None:
50+
raise ValueError(
51+
"The model does not have a minimum neighbor distance, "
52+
"so the training script and data must be provided "
53+
"(via -t,--training-script)."
54+
)
55+
56+
jdata = j_loader(training_script)
57+
jdata = update_deepmd_input(jdata)
58+
59+
type_map = jdata["model"].get("type_map", None)
60+
train_data = get_data(
61+
jdata["training"]["training_data"],
62+
0, # not used
63+
type_map,
64+
None,
65+
)
66+
update_sel = UpdateSel()
67+
t_min_nbor_dist = update_sel.get_min_nbor_dist(
68+
train_data,
69+
)
70+
model.min_nbor_dist = torch.tensor(
71+
t_min_nbor_dist,
72+
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
73+
device=env.DEVICE,
74+
)
75+
2376
model.enable_compression(
2477
extrapolate,
2578
stride,

deepmd/pt/entrypoints/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
565565
stride=FLAGS.step,
566566
extrapolate=FLAGS.extrapolate,
567567
check_frequency=FLAGS.frequency,
568+
training_script=FLAGS.training_script,
568569
)
569570
else:
570571
raise RuntimeError(f"Invalid command {FLAGS.command}!")

source/tests/pt/test_model_compression_se_a.py

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,48 @@ def _init_models_exclude_types():
7474
return INPUT, frozen_model, compressed_model
7575

7676

77+
def _init_models_skip_neighbor_stat():
    """Train, freeze and compress a model with ``--skip-neighbor-stat``.

    Writes a training input (copied from ``model_compression/input.json``
    with its training systems pointed at ``model_compression/data``) to
    ``input.json`` under ``tests_path``, then runs ``dp --pt train`` with
    ``--skip-neighbor-stat``, ``dp --pt freeze`` and ``dp --pt compress``
    (passing the input via ``-t``). Each command must return 0.

    Returns
    -------
    tuple of str
        Paths of the written input file, the frozen model and the
        compressed model, in that order.
    """
    tag = "-skip-neighbor-stat"
    input_path = str(tests_path / "input.json")
    original_path = str(tests_path / f"dp-original{tag}.pth")
    compressed_path = str(tests_path / f"dp-compressed{tag}.pth")

    config = j_loader(str(tests_path / os.path.join("model_compression", "input.json")))
    config["training"]["training_data"]["systems"] = str(
        tests_path / os.path.join("model_compression", "data")
    )
    with open(input_path, "w") as handle:
        json.dump(config, handle, indent=4)

    # The command pieces are kept verbatim (including the doubled space
    # before "-i") so the strings handed to run_dp are unchanged.
    compress_command = "".join(
        (
            "dp --pt compress ",
            " -i ",
            original_path,
            " -o ",
            compressed_path,
            " -t ",
            input_path,
        )
    )
    pipeline = (
        (f"dp --pt train {input_path} --skip-neighbor-stat", "DP train failed!"),
        (f"dp --pt freeze -o {original_path}", "DP freeze failed!"),
        (compress_command, "DP model compression failed!"),
    )
    # Run the stages in order; stop at the first non-zero return code.
    for command, failure_message in pipeline:
        np.testing.assert_equal(run_dp(command), 0, failure_message)

    return input_path, original_path, compressed_path
103+
104+
77105
def setUpModule():
78106
global \
79107
INPUT, \
80108
FROZEN_MODEL, \
81109
COMPRESSED_MODEL, \
82110
INPUT_ET, \
83111
FROZEN_MODEL_ET, \
84-
COMPRESSED_MODEL_ET
112+
COMPRESSED_MODEL_ET, \
113+
FROZEN_MODEL_SKIP_NEIGHBOR_STAT, \
114+
COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT
85115
INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models()
116+
_, FROZEN_MODEL_SKIP_NEIGHBOR_STAT, COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT = (
117+
_init_models_skip_neighbor_stat()
118+
)
86119
INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types()
87120

88121

@@ -572,5 +605,129 @@ def test_2frame_atm(self):
572605
np.testing.assert_almost_equal(vv0, vv1, default_places)
573606

574607

608+
class TestSkipNeighborStat(unittest.TestCase):
    """Compare the ``--skip-neighbor-stat`` model before and after compression.

    The frozen and compressed models are loaded from the module-level paths
    ``FROZEN_MODEL_SKIP_NEIGHBOR_STAT`` and ``COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT``
    and evaluated on identical inputs. Output shapes are checked for each
    model, then the outputs of the two models must agree to
    ``default_places`` decimals.
    """

    # Order in which outputs of the two models are compared; keys that a
    # non-atomic evaluation does not produce are simply skipped.
    _COMPARE_ORDER = ("ff", "ae", "av", "ee", "vv")

    @classmethod
    def setUpClass(cls):
        """Load both models and define the shared six-atom test frame."""
        cls.dp_original = DeepEval(FROZEN_MODEL_SKIP_NEIGHBOR_STAT)
        cls.dp_compressed = DeepEval(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT)
        # Flat array of 18 values: three numbers per entry of ``atype``.
        cls.coords = np.array(
            [
                12.83, 2.56, 2.18,
                12.09, 2.87, 2.74,
                0.25, 3.32, 1.68,
                3.36, 3.00, 1.81,
                3.51, 2.51, 2.60,
                4.27, 3.22, 1.56,
            ]
        )
        cls.atype = [0, 1, 1, 0, 1, 1]
        cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0])

    def _evaluate(self, evaluator, coords, box, atomic):
        """Run ``evaluator.eval`` and return its outputs keyed by short name."""
        if atomic:
            ee, ff, vv, ae, av = evaluator.eval(coords, box, self.atype, atomic=True)
            return {"ee": ee, "ff": ff, "vv": vv, "ae": ae, "av": av}
        ee, ff, vv = evaluator.eval(coords, box, self.atype, atomic=False)
        return {"ee": ee, "ff": ff, "vv": vv}

    def _assert_shapes(self, outputs, nframes):
        """Assert every array in ``outputs`` has the shape expected for ``nframes``."""
        natoms = len(self.atype)
        expected = {
            "ee": (nframes, 1),
            "ff": (nframes, natoms, 3),
            "vv": (nframes, 9),
            "ae": (nframes, natoms, 1),
            "av": (nframes, natoms, 9),
        }
        for key, value in outputs.items():
            self.assertEqual(value.shape, expected[key])

    def _assert_models_agree(self, coords, box, nframes, atomic):
        """Evaluate both models, check shapes, then compare their outputs."""
        reference = self._evaluate(self.dp_original, coords, box, atomic)
        candidate = self._evaluate(self.dp_compressed, coords, box, atomic)
        # check shape of the returns
        self._assert_shapes(reference, nframes)
        self._assert_shapes(candidate, nframes)
        # check values
        for key in self._COMPARE_ORDER:
            if key in reference:
                np.testing.assert_almost_equal(
                    reference[key], candidate[key], default_places
                )

    def test_attrs(self):
        """Both models report the same basic attributes."""
        for evaluator in (self.dp_original, self.dp_compressed):
            self.assertEqual(evaluator.get_ntypes(), 2)
            self.assertAlmostEqual(evaluator.get_rcut(), 6.0, places=default_places)
            self.assertEqual(evaluator.get_type_map(), ["O", "H"])
            self.assertEqual(evaluator.get_dim_fparam(), 0)
            self.assertEqual(evaluator.get_dim_aparam(), 0)

    def test_1frame(self):
        """Single frame, non-atomic evaluation."""
        self._assert_models_agree(self.coords, self.box, nframes=1, atomic=False)

    def test_1frame_atm(self):
        """Single frame, atomic evaluation."""
        self._assert_models_agree(self.coords, self.box, nframes=1, atomic=True)

    def test_2frame_atm(self):
        """Two identical frames, atomic evaluation."""
        doubled_coords = np.concatenate([self.coords] * 2)
        doubled_box = np.concatenate([self.box] * 2)
        self._assert_models_agree(doubled_coords, doubled_box, nframes=2, atomic=True)
731+
575732
if __name__ == "__main__":
576733
unittest.main()

0 commit comments

Comments
 (0)