77)
88
99import torch
10+ import torch .nn as nn
1011
1112from deepmd .dpmodel .utils import EnvMat as DPEnvMat
1213from deepmd .dpmodel .utils .seed import (
3940from deepmd .pt .utils .exclude_mask import (
4041 PairExcludeMask ,
4142)
43+ from deepmd .pt .utils .tabulate import (
44+ DPTabulate ,
45+ )
4246from deepmd .pt .utils .update_sel import (
4347 UpdateSel ,
4448)
49+ from deepmd .pt .utils .utils import (
50+ ActivationFn ,
51+ )
4552from deepmd .utils .data_system import (
4653 DeepmdDataSystem ,
4754)
@@ -181,6 +188,7 @@ def __init__(
181188 self .tebd_input_mode = tebd_input_mode
182189 self .concat_output_tebd = concat_output_tebd
183190 self .trainable = trainable
191+ self .compress = False
184192 # set trainable
185193 for param in self .parameters ():
186194 param .requires_grad = trainable
@@ -516,6 +524,83 @@ def update_sel(
516524 local_jdata_cpy ["sel" ] = sel [0 ]
517525 return local_jdata_cpy , min_nbor_dist
518526
527+ def enable_compression (
528+ self ,
529+ min_nbor_dist : float ,
530+ table_extrapolate : float = 5 ,
531+ table_stride_1 : float = 0.01 ,
532+ table_stride_2 : float = 0.1 ,
533+ check_frequency : int = - 1 ,
534+ ) -> None :
535+ """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.
536+
537+ Parameters
538+ ----------
539+ min_nbor_dist
540+ The nearest distance between atoms
541+ table_extrapolate
542+ The scale of model extrapolation
543+ table_stride_1
544+ The uniform stride of the first table
545+ table_stride_2
546+ The uniform stride of the second table
547+ check_frequency
548+ The overflow check frequency
549+ """
550+ # do some checks before the model compression process
551+ if self .compress :
552+ raise ValueError ("Compression is already enabled." )
553+ assert not self .se_ttebd .resnet_dt , (
554+ "Model compression error: descriptor resnet_dt must be false!"
555+ )
556+ for tt in self .se_ttebd .exclude_types :
557+ if (tt [0 ] not in range (self .se_ttebd .ntypes )) or (
558+ tt [1 ] not in range (self .se_ttebd .ntypes )
559+ ):
560+ raise RuntimeError (
561+ "exclude types"
562+ + str (tt )
563+ + " must within the number of atomic types "
564+ + str (self .se_ttebd .ntypes )
565+ + "!"
566+ )
567+ if (
568+ self .se_ttebd .ntypes * self .se_ttebd .ntypes
569+ - len (self .se_ttebd .exclude_types )
570+ == 0
571+ ):
572+ raise RuntimeError (
573+ "Empty embedding-nets are not supported in model compression!"
574+ )
575+
576+ if self .tebd_input_mode != "strip" :
577+ raise RuntimeError ("Cannot compress model when tebd_input_mode == 'concat'" )
578+
579+ data = self .serialize ()
580+ self .table = DPTabulate (
581+ self ,
582+ data ["neuron" ],
583+ exclude_types = data ["exclude_types" ],
584+ activation_fn = ActivationFn (data ["activation_function" ]),
585+ )
586+ # Scale the stride values for SE_T descriptor
587+ stride_1_scaled = table_stride_1 * 10
588+ stride_2_scaled = table_stride_2 * 10
589+ self .table_config = [
590+ table_extrapolate ,
591+ stride_1_scaled ,
592+ stride_2_scaled ,
593+ check_frequency ,
594+ ]
595+ self .lower , self .upper = self .table .build (
596+ min_nbor_dist , table_extrapolate , stride_1_scaled , stride_2_scaled
597+ )
598+
599+ self .se_ttebd .enable_compression (
600+ self .table .data , self .table_config , self .lower , self .upper
601+ )
602+ self .compress = True
603+
519604
520605@DescriptorBlock .register ("se_ttebd" )
521606class DescrptBlockSeTTebd (DescriptorBlock ):
@@ -607,6 +692,14 @@ def __init__(
607692 )
608693 self .filter_layers_strip = filter_layers_strip
609694 self .stats = None
695+ # compression related variables
696+ self .compress = False
697+ self .compress_info = nn .ParameterList (
698+ [nn .Parameter (torch .zeros (0 , dtype = self .prec , device = "cpu" ))]
699+ )
700+ self .compress_data = nn .ParameterList (
701+ [nn .Parameter (torch .zeros (0 , dtype = self .prec , device = env .DEVICE ))]
702+ )
610703
611704 def get_rcut (self ) -> float :
612705 """Returns the cut-off radius."""
@@ -811,6 +904,7 @@ def forward(
811904 self .rcut_smth ,
812905 protection = self .env_protection ,
813906 )
907+ # dmatrix: [1/r, dx/r^2, dy/r^2, dz/r^2], sw: distance weighting
814908 # nb x nloc x nnei
815909 exclude_mask = self .emask (nlist , extended_atype )
816910 nlist = torch .where (exclude_mask != 0 , nlist , - 1 )
@@ -831,11 +925,13 @@ def forward(
831925 rr = dmatrix
832926 rr = rr * exclude_mask [:, :, None ]
833927
834- # nfnl x nt_i x 3
928+ # nfnl x nt_i x 3: direction vectors
929+ # nt_i = nnei
930+ # nt_j = nnei
835931 rr_i = rr [:, :, 1 :]
836932 # nfnl x nt_j x 3
837933 rr_j = rr [:, :, 1 :]
838- # nfnl x nt_i x nt_j
934+ # nfnl x nt_i x nt_j: three-body angular correlations (cos theta_ij)
839935 env_ij = torch .einsum ("ijm,ikm->ijk" , rr_i , rr_j )
840936 # nfnl x nt_i x nt_j x 1
841937 ss = env_ij .unsqueeze (- 1 )
@@ -857,8 +953,24 @@ def forward(
857953 # nfnl x nt_i x nt_j x ng
858954 gg = self .filter_layers .networks [0 ](ss )
859955 elif self .tebd_input_mode in ["strip" ]:
860- # nfnl x nt_i x nt_j x ng
861- gg_s = self .filter_layers .networks [0 ](ss )
956+ if self .compress :
957+ # Tabulated geometric embedding from angular features
958+ # using SE_T_TEBD specific function
959+ ebd_env_ij = env_ij .view (- 1 , 1 )
960+ gg_s = torch .ops .deepmd .tabulate_fusion_se_t_tebd (
961+ self .compress_data [0 ].contiguous (),
962+ self .compress_info [0 ].cpu ().contiguous (),
963+ ebd_env_ij .contiguous (),
964+ env_ij .contiguous (),
965+ self .filter_neuron [- 1 ],
966+ )[0 ]
967+ # SE_T_TEBD tabulation preserves the full neighbor structure
968+ # nfnl x nt_i x nt_j x ng
969+ gg_s = gg_s .view (nfnl , nnei , nnei , self .filter_neuron [- 1 ])
970+ else :
971+ # nfnl x nt_i x nt_j x ng
972+ gg_s = self .filter_layers .networks [0 ](ss )
973+
862974 assert self .filter_layers_strip is not None
863975 assert type_embedding is not None
864976 ng = self .filter_neuron [- 1 ]
@@ -902,16 +1014,19 @@ def forward(
9021014 # (nfnl x nt_i x nt_j) x ng
9031015 gg_t = gg_t .reshape (nfnl , nnei , nnei , ng )
9041016 if self .smooth :
1017+ # Apply distance weighting to type features
9051018 gg_t = (
9061019 gg_t
9071020 * sw .reshape (nfnl , self .nnei , 1 , 1 )
9081021 * sw .reshape (nfnl , 1 , self .nnei , 1 )
9091022 )
1023+ # Combine geometric and type embeddings: gg_s * (1 + gg_t)
9101024 # nfnl x nt_i x nt_j x ng
9111025 gg = gg_s * gg_t + gg_s
9121026 else :
9131027 raise NotImplementedError
9141028
1029+ # Contract angular correlations with learned features
9151030 # nfnl x ng
9161031 res_ij = torch .einsum ("ijk,ijkm->im" , env_ij , gg )
9171032 res_ij = res_ij * (1.0 / float (self .nnei ) / float (self .nnei ))
@@ -925,6 +1040,46 @@ def forward(
9251040 sw ,
9261041 )
9271042
1043+ def enable_compression (
1044+ self ,
1045+ table_data : dict ,
1046+ table_config : dict ,
1047+ lower : dict ,
1048+ upper : dict ,
1049+ ) -> None :
1050+ """Enable compression for the SE_T_TEBD descriptor block.
1051+
1052+ Parameters
1053+ ----------
1054+ table_data : dict
1055+ The tabulated data from DPTabulate
1056+ table_config : dict
1057+ Configuration for table compression
1058+ lower : dict
1059+ Lower bounds for compression
1060+ upper : dict
1061+ Upper bounds for compression
1062+ """
1063+ # Compress the main geometric embedding network (self.filter_layers)
1064+ net_key = "filter_net"
1065+ self .compress_info [0 ] = torch .as_tensor (
1066+ [
1067+ lower [net_key ],
1068+ upper [net_key ],
1069+ upper [net_key ] * table_config [0 ],
1070+ table_config [1 ],
1071+ table_config [2 ],
1072+ table_config [3 ],
1073+ ],
1074+ dtype = self .prec ,
1075+ device = "cpu" ,
1076+ )
1077+ self .compress_data [0 ] = table_data [net_key ].to (
1078+ device = env .DEVICE , dtype = self .prec
1079+ )
1080+
1081+ self .compress = True
1082+
9281083 def has_message_passing (self ) -> bool :
9291084 """Returns whether the descriptor block has message passing."""
9301085 return False
0 commit comments