@@ -74,15 +74,48 @@ def _init_models_exclude_types():
7474 return INPUT , frozen_model , compressed_model
7575
7676
77+ def _init_models_skip_neighbor_stat ():
78+ suffix = "-skip-neighbor-stat"
79+ data_file = str (tests_path / os .path .join ("model_compression" , "data" ))
80+ frozen_model = str (tests_path / f"dp-original{ suffix } .pth" )
81+ compressed_model = str (tests_path / f"dp-compressed{ suffix } .pth" )
82+ INPUT = str (tests_path / "input.json" )
83+ jdata = j_loader (str (tests_path / os .path .join ("model_compression" , "input.json" )))
84+ jdata ["training" ]["training_data" ]["systems" ] = data_file
85+ with open (INPUT , "w" ) as fp :
86+ json .dump (jdata , fp , indent = 4 )
87+
88+ ret = run_dp ("dp --pt train " + INPUT + " --skip-neighbor-stat" )
89+ np .testing .assert_equal (ret , 0 , "DP train failed!" )
90+ ret = run_dp ("dp --pt freeze -o " + frozen_model )
91+ np .testing .assert_equal (ret , 0 , "DP freeze failed!" )
92+ ret = run_dp (
93+ "dp --pt compress "
94+ + " -i "
95+ + frozen_model
96+ + " -o "
97+ + compressed_model
98+ + " -t "
99+ + INPUT
100+ )
101+ np .testing .assert_equal (ret , 0 , "DP model compression failed!" )
102+ return INPUT , frozen_model , compressed_model
103+
104+
77105def setUpModule ():
78106 global \
79107 INPUT , \
80108 FROZEN_MODEL , \
81109 COMPRESSED_MODEL , \
82110 INPUT_ET , \
83111 FROZEN_MODEL_ET , \
84- COMPRESSED_MODEL_ET
112+ COMPRESSED_MODEL_ET , \
113+ FROZEN_MODEL_SKIP_NEIGHBOR_STAT , \
114+ COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT
85115 INPUT , FROZEN_MODEL , COMPRESSED_MODEL = _init_models ()
116+ _ , FROZEN_MODEL_SKIP_NEIGHBOR_STAT , COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT = (
117+ _init_models_skip_neighbor_stat ()
118+ )
86119 INPUT_ET , FROZEN_MODEL_ET , COMPRESSED_MODEL_ET = _init_models_exclude_types ()
87120
88121
@@ -572,5 +605,129 @@ def test_2frame_atm(self):
572605 np .testing .assert_almost_equal (vv0 , vv1 , default_places )
573606
574607
608+ class TestSkipNeighborStat (unittest .TestCase ):
609+ @classmethod
610+ def setUpClass (cls ):
611+ cls .dp_original = DeepEval (FROZEN_MODEL_SKIP_NEIGHBOR_STAT )
612+ cls .dp_compressed = DeepEval (COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT )
613+ cls .coords = np .array (
614+ [
615+ 12.83 ,
616+ 2.56 ,
617+ 2.18 ,
618+ 12.09 ,
619+ 2.87 ,
620+ 2.74 ,
621+ 00.25 ,
622+ 3.32 ,
623+ 1.68 ,
624+ 3.36 ,
625+ 3.00 ,
626+ 1.81 ,
627+ 3.51 ,
628+ 2.51 ,
629+ 2.60 ,
630+ 4.27 ,
631+ 3.22 ,
632+ 1.56 ,
633+ ]
634+ )
635+ cls .atype = [0 , 1 , 1 , 0 , 1 , 1 ]
636+ cls .box = np .array ([13.0 , 0.0 , 0.0 , 0.0 , 13.0 , 0.0 , 0.0 , 0.0 , 13.0 ])
637+
638+ def test_attrs (self ):
639+ self .assertEqual (self .dp_original .get_ntypes (), 2 )
640+ self .assertAlmostEqual (self .dp_original .get_rcut (), 6.0 , places = default_places )
641+ self .assertEqual (self .dp_original .get_type_map (), ["O" , "H" ])
642+ self .assertEqual (self .dp_original .get_dim_fparam (), 0 )
643+ self .assertEqual (self .dp_original .get_dim_aparam (), 0 )
644+
645+ self .assertEqual (self .dp_compressed .get_ntypes (), 2 )
646+ self .assertAlmostEqual (
647+ self .dp_compressed .get_rcut (), 6.0 , places = default_places
648+ )
649+ self .assertEqual (self .dp_compressed .get_type_map (), ["O" , "H" ])
650+ self .assertEqual (self .dp_compressed .get_dim_fparam (), 0 )
651+ self .assertEqual (self .dp_compressed .get_dim_aparam (), 0 )
652+
653+ def test_1frame (self ):
654+ ee0 , ff0 , vv0 = self .dp_original .eval (
655+ self .coords , self .box , self .atype , atomic = False
656+ )
657+ ee1 , ff1 , vv1 = self .dp_compressed .eval (
658+ self .coords , self .box , self .atype , atomic = False
659+ )
660+ # check shape of the returns
661+ nframes = 1
662+ natoms = len (self .atype )
663+ self .assertEqual (ee0 .shape , (nframes , 1 ))
664+ self .assertEqual (ff0 .shape , (nframes , natoms , 3 ))
665+ self .assertEqual (vv0 .shape , (nframes , 9 ))
666+ self .assertEqual (ee1 .shape , (nframes , 1 ))
667+ self .assertEqual (ff1 .shape , (nframes , natoms , 3 ))
668+ self .assertEqual (vv1 .shape , (nframes , 9 ))
669+ # check values
670+ np .testing .assert_almost_equal (ff0 , ff1 , default_places )
671+ np .testing .assert_almost_equal (ee0 , ee1 , default_places )
672+ np .testing .assert_almost_equal (vv0 , vv1 , default_places )
673+
674+ def test_1frame_atm (self ):
675+ ee0 , ff0 , vv0 , ae0 , av0 = self .dp_original .eval (
676+ self .coords , self .box , self .atype , atomic = True
677+ )
678+ ee1 , ff1 , vv1 , ae1 , av1 = self .dp_compressed .eval (
679+ self .coords , self .box , self .atype , atomic = True
680+ )
681+ # check shape of the returns
682+ nframes = 1
683+ natoms = len (self .atype )
684+ self .assertEqual (ee0 .shape , (nframes , 1 ))
685+ self .assertEqual (ff0 .shape , (nframes , natoms , 3 ))
686+ self .assertEqual (vv0 .shape , (nframes , 9 ))
687+ self .assertEqual (ae0 .shape , (nframes , natoms , 1 ))
688+ self .assertEqual (av0 .shape , (nframes , natoms , 9 ))
689+ self .assertEqual (ee1 .shape , (nframes , 1 ))
690+ self .assertEqual (ff1 .shape , (nframes , natoms , 3 ))
691+ self .assertEqual (vv1 .shape , (nframes , 9 ))
692+ self .assertEqual (ae1 .shape , (nframes , natoms , 1 ))
693+ self .assertEqual (av1 .shape , (nframes , natoms , 9 ))
694+ # check values
695+ np .testing .assert_almost_equal (ff0 , ff1 , default_places )
696+ np .testing .assert_almost_equal (ae0 , ae1 , default_places )
697+ np .testing .assert_almost_equal (av0 , av1 , default_places )
698+ np .testing .assert_almost_equal (ee0 , ee1 , default_places )
699+ np .testing .assert_almost_equal (vv0 , vv1 , default_places )
700+
701+ def test_2frame_atm (self ):
702+ coords2 = np .concatenate ((self .coords , self .coords ))
703+ box2 = np .concatenate ((self .box , self .box ))
704+ ee0 , ff0 , vv0 , ae0 , av0 = self .dp_original .eval (
705+ coords2 , box2 , self .atype , atomic = True
706+ )
707+ ee1 , ff1 , vv1 , ae1 , av1 = self .dp_compressed .eval (
708+ coords2 , box2 , self .atype , atomic = True
709+ )
710+ # check shape of the returns
711+ nframes = 2
712+ natoms = len (self .atype )
713+ self .assertEqual (ee0 .shape , (nframes , 1 ))
714+ self .assertEqual (ff0 .shape , (nframes , natoms , 3 ))
715+ self .assertEqual (vv0 .shape , (nframes , 9 ))
716+ self .assertEqual (ae0 .shape , (nframes , natoms , 1 ))
717+ self .assertEqual (av0 .shape , (nframes , natoms , 9 ))
718+ self .assertEqual (ee1 .shape , (nframes , 1 ))
719+ self .assertEqual (ff1 .shape , (nframes , natoms , 3 ))
720+ self .assertEqual (vv1 .shape , (nframes , 9 ))
721+ self .assertEqual (ae1 .shape , (nframes , natoms , 1 ))
722+ self .assertEqual (av1 .shape , (nframes , natoms , 9 ))
723+
724+ # check values
725+ np .testing .assert_almost_equal (ff0 , ff1 , default_places )
726+ np .testing .assert_almost_equal (ae0 , ae1 , default_places )
727+ np .testing .assert_almost_equal (av0 , av1 , default_places )
728+ np .testing .assert_almost_equal (ee0 , ee1 , default_places )
729+ np .testing .assert_almost_equal (vv0 , vv1 , default_places )
730+
731+
575732if __name__ == "__main__" :
576733 unittest .main ()
0 commit comments