Commit 9190888

feat(pt): add compression support for se_e3_tebd (#4992)
add compression support for se_e3_tebd

## Summary by CodeRabbit

- **New Features**
  - TEBD-style descriptor support with optional runtime compression and a Python-accessible tabulation operation that supports autograd.
- **Performance Improvements**
  - Shared/global tabulation tables and shared embeddings to avoid redundant table builds; improved device-aware tensor handling for CPU/GPU.
- **Reliability**
  - Validation and guards to prevent repeated compression enabling and to check input/embedding suitability.
- **Tests**
  - New unit and integration tests covering forward/backward CPU and GPU paths.
- **Documentation**
  - Updated docs to indicate model compression is supported.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6f00250 · commit 9190888

10 files changed: +3,717 −30 lines changed
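Usage note: the descriptor-level `enable_compression()` added in this commit is what turns tabulated compression on for a trained se_e3_tebd descriptor. The sketch below shows how it might be invoked; the constructor arguments and the `min_nbor_dist` value are assumptions for illustration (in practice the statistics come from the training data), not something this commit prescribes.

```python
# Illustrative sketch only; constructor arguments are assumed, not taken from this commit.
from deepmd.pt.model.descriptor import DescrptSeTTebd

descrpt = DescrptSeTTebd(
    rcut=6.0,
    rcut_smth=0.5,
    sel=40,
    ntypes=2,
    tebd_input_mode="strip",  # compression requires "strip" mode (see the check below)
    resnet_dt=False,          # resnet_dt must be False for compression
)
# min_nbor_dist would normally come from training-data statistics.
descrpt.enable_compression(min_nbor_dist=0.9)
```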

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 159 additions & 4 deletions
@@ -7,6 +7,7 @@
 )

 import torch
+import torch.nn as nn

 from deepmd.dpmodel.utils import EnvMat as DPEnvMat
 from deepmd.dpmodel.utils.seed import (
@@ -39,9 +40,15 @@
 from deepmd.pt.utils.exclude_mask import (
     PairExcludeMask,
 )
+from deepmd.pt.utils.tabulate import (
+    DPTabulate,
+)
 from deepmd.pt.utils.update_sel import (
     UpdateSel,
 )
+from deepmd.pt.utils.utils import (
+    ActivationFn,
+)
 from deepmd.utils.data_system import (
     DeepmdDataSystem,
 )
@@ -181,6 +188,7 @@ def __init__(
         self.tebd_input_mode = tebd_input_mode
         self.concat_output_tebd = concat_output_tebd
         self.trainable = trainable
+        self.compress = False
         # set trainable
         for param in self.parameters():
             param.requires_grad = trainable
@@ -516,6 +524,83 @@ def update_sel(
         local_jdata_cpy["sel"] = sel[0]
         return local_jdata_cpy, min_nbor_dist

+    def enable_compression(
+        self,
+        min_nbor_dist: float,
+        table_extrapolate: float = 5,
+        table_stride_1: float = 0.01,
+        table_stride_2: float = 0.1,
+        check_frequency: int = -1,
+    ) -> None:
+        """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.
+
+        Parameters
+        ----------
+        min_nbor_dist
+            The nearest distance between atoms
+        table_extrapolate
+            The scale of model extrapolation
+        table_stride_1
+            The uniform stride of the first table
+        table_stride_2
+            The uniform stride of the second table
+        check_frequency
+            The overflow check frequency
+        """
+        # do some checks before the model compression process
+        if self.compress:
+            raise ValueError("Compression is already enabled.")
+        assert not self.se_ttebd.resnet_dt, (
+            "Model compression error: descriptor resnet_dt must be false!"
+        )
+        for tt in self.se_ttebd.exclude_types:
+            if (tt[0] not in range(self.se_ttebd.ntypes)) or (
+                tt[1] not in range(self.se_ttebd.ntypes)
+            ):
+                raise RuntimeError(
+                    "exclude types"
+                    + str(tt)
+                    + " must within the number of atomic types "
+                    + str(self.se_ttebd.ntypes)
+                    + "!"
+                )
+        if (
+            self.se_ttebd.ntypes * self.se_ttebd.ntypes
+            - len(self.se_ttebd.exclude_types)
+            == 0
+        ):
+            raise RuntimeError(
+                "Empty embedding-nets are not supported in model compression!"
+            )
+
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")
+
+        data = self.serialize()
+        self.table = DPTabulate(
+            self,
+            data["neuron"],
+            exclude_types=data["exclude_types"],
+            activation_fn=ActivationFn(data["activation_function"]),
+        )
+        # Scale the stride values for SE_T descriptor
+        stride_1_scaled = table_stride_1 * 10
+        stride_2_scaled = table_stride_2 * 10
+        self.table_config = [
+            table_extrapolate,
+            stride_1_scaled,
+            stride_2_scaled,
+            check_frequency,
+        ]
+        self.lower, self.upper = self.table.build(
+            min_nbor_dist, table_extrapolate, stride_1_scaled, stride_2_scaled
+        )
+
+        self.se_ttebd.enable_compression(
+            self.table.data, self.table_config, self.lower, self.upper
+        )
+        self.compress = True
+

 @DescriptorBlock.register("se_ttebd")
 class DescrptBlockSeTTebd(DescriptorBlock):
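Note on the table setup above: the user-facing strides are multiplied by 10 before `DPTabulate.build` is called, following the comment about the SE_T-style table. The worked example below traces the resulting numbers for the default arguments; the lower/upper bounds and the fine/coarse grid interpretation are assumptions for illustration, not values produced by this commit.

```python
# Worked example with the default enable_compression() arguments.
table_extrapolate = 5.0
table_stride_1 = 0.01
table_stride_2 = 0.1

stride_1_scaled = table_stride_1 * 10  # 0.1, assumed to cover [lower, upper]
stride_2_scaled = table_stride_2 * 10  # 1.0, assumed to cover the extrapolation region

# Assumed bounds for the shared "filter_net" table (normally returned by DPTabulate.build).
lower, upper = -1.0, 1.0

n_fine = round((upper - lower) / stride_1_scaled)                        # 20 grid points
n_coarse = round((upper * table_extrapolate - upper) / stride_2_scaled)  # 4 grid points
table_config = [table_extrapolate, stride_1_scaled, stride_2_scaled, -1]
print(n_fine, n_coarse, table_config)
```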
@@ -607,6 +692,14 @@ def __init__(
         )
         self.filter_layers_strip = filter_layers_strip
         self.stats = None
+        # compression related variables
+        self.compress = False
+        self.compress_info = nn.ParameterList(
+            [nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
+        )
+        self.compress_data = nn.ParameterList(
+            [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
+        )

     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -811,6 +904,7 @@ def forward(
             self.rcut_smth,
             protection=self.env_protection,
         )
+        # dmatrix: [1/r, dx/r^2, dy/r^2, dz/r^2], sw: distance weighting
         # nb x nloc x nnei
         exclude_mask = self.emask(nlist, extended_atype)
         nlist = torch.where(exclude_mask != 0, nlist, -1)
@@ -831,11 +925,13 @@ def forward(
         rr = dmatrix
         rr = rr * exclude_mask[:, :, None]

-        # nfnl x nt_i x 3
+        # nfnl x nt_i x 3: direction vectors
+        # nt_i = nnei
+        # nt_j = nnei
         rr_i = rr[:, :, 1:]
         # nfnl x nt_j x 3
         rr_j = rr[:, :, 1:]
-        # nfnl x nt_i x nt_j
+        # nfnl x nt_i x nt_j: three-body angular correlations (cos theta_ij)
         env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j)
         # nfnl x nt_i x nt_j x 1
         ss = env_ij.unsqueeze(-1)
@@ -857,8 +953,24 @@ def forward(
             # nfnl x nt_i x nt_j x ng
             gg = self.filter_layers.networks[0](ss)
         elif self.tebd_input_mode in ["strip"]:
-            # nfnl x nt_i x nt_j x ng
-            gg_s = self.filter_layers.networks[0](ss)
+            if self.compress:
+                # Tabulated geometric embedding from angular features
+                # using SE_T_TEBD specific function
+                ebd_env_ij = env_ij.view(-1, 1)
+                gg_s = torch.ops.deepmd.tabulate_fusion_se_t_tebd(
+                    self.compress_data[0].contiguous(),
+                    self.compress_info[0].cpu().contiguous(),
+                    ebd_env_ij.contiguous(),
+                    env_ij.contiguous(),
+                    self.filter_neuron[-1],
+                )[0]
+                # SE_T_TEBD tabulation preserves the full neighbor structure
+                # nfnl x nt_i x nt_j x ng
+                gg_s = gg_s.view(nfnl, nnei, nnei, self.filter_neuron[-1])
+            else:
+                # nfnl x nt_i x nt_j x ng
+                gg_s = self.filter_layers.networks[0](ss)
+
             assert self.filter_layers_strip is not None
             assert type_embedding is not None
             ng = self.filter_neuron[-1]
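In the compressed branch the geometric embedding network is no longer evaluated; `torch.ops.deepmd.tabulate_fusion_se_t_tebd` reads a precomputed table instead. The snippet below is not that kernel: it is a minimal sketch of the underlying idea (tabulate a trained scalar function on a uniform grid once, then interpolate at run time), with the function, grid, and interpolation order all chosen for illustration.

```python
import torch

def build_table(fn, lower, upper, stride):
    """Pre-tabulate fn on a uniform grid (illustrative, not the DeePMD table format)."""
    xs = torch.arange(lower, upper + stride, stride)
    return xs, fn(xs)

def lookup(xs, ys, x):
    """Piecewise-linear interpolation into the table (the real kernel is higher order)."""
    idx = torch.clamp(((x - xs[0]) / (xs[1] - xs[0])).long(), 0, len(xs) - 2)
    x0, x1 = xs[idx], xs[idx + 1]
    w = (x - x0) / (x1 - x0)
    return (1 - w) * ys[idx] + w * ys[idx + 1]

fn = torch.tanh                      # stand-in for the trained embedding net
xs, ys = build_table(fn, -1.0, 1.0, 0.01)
x = torch.rand(5) * 2 - 1            # angular inputs in [-1, 1]
print(torch.allclose(lookup(xs, ys, x), fn(x), atol=1e-3))  # close to the exact values
```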
@@ -902,16 +1014,19 @@ def forward(
             # (nfnl x nt_i x nt_j) x ng
             gg_t = gg_t.reshape(nfnl, nnei, nnei, ng)
             if self.smooth:
+                # Apply distance weighting to type features
                 gg_t = (
                     gg_t
                     * sw.reshape(nfnl, self.nnei, 1, 1)
                     * sw.reshape(nfnl, 1, self.nnei, 1)
                 )
+            # Combine geometric and type embeddings: gg_s * (1 + gg_t)
             # nfnl x nt_i x nt_j x ng
             gg = gg_s * gg_t + gg_s
         else:
             raise NotImplementedError

+        # Contract angular correlations with learned features
         # nfnl x ng
         res_ij = torch.einsum("ijk,ijkm->im", env_ij, gg)
         res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
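Both branches above feed the same combination step: the geometric embedding gg_s is modulated by the type-pair embedding as gg = gg_s * (1 + gg_t) and then contracted with the angular matrix env_ij, normalized by nnei^2. A shape-level sketch with random tensors (sizes assumed for illustration):

```python
import torch

nfnl, nnei, ng = 4, 6, 8                  # assumed sizes for illustration
env_ij = torch.rand(nfnl, nnei, nnei)     # angular correlations
gg_s = torch.rand(nfnl, nnei, nnei, ng)   # geometric embedding
gg_t = torch.rand(nfnl, nnei, nnei, ng)   # type-pair embedding

gg = gg_s * gg_t + gg_s                   # equivalent to gg_s * (1 + gg_t)
res_ij = torch.einsum("ijk,ijkm->im", env_ij, gg)
res_ij = res_ij * (1.0 / float(nnei) / float(nnei))
print(res_ij.shape)                       # torch.Size([4, 8]) -> nfnl x ng
```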
@@ -925,6 +1040,46 @@ def forward(
             sw,
         )

+    def enable_compression(
+        self,
+        table_data: dict,
+        table_config: dict,
+        lower: dict,
+        upper: dict,
+    ) -> None:
+        """Enable compression for the SE_T_TEBD descriptor block.
+
+        Parameters
+        ----------
+        table_data : dict
+            The tabulated data from DPTabulate
+        table_config : dict
+            Configuration for table compression
+        lower : dict
+            Lower bounds for compression
+        upper : dict
+            Upper bounds for compression
+        """
+        # Compress the main geometric embedding network (self.filter_layers)
+        net_key = "filter_net"
+        self.compress_info[0] = torch.as_tensor(
+            [
+                lower[net_key],
+                upper[net_key],
+                upper[net_key] * table_config[0],
+                table_config[1],
+                table_config[2],
+                table_config[3],
+            ],
+            dtype=self.prec,
+            device="cpu",
+        )
+        self.compress_data[0] = table_data[net_key].to(
+            device=env.DEVICE, dtype=self.prec
+        )
+
+        self.compress = True
+
     def has_message_passing(self) -> bool:
         """Returns whether the descriptor block has message passing."""
         return False
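For reference, the six entries packed into compress_info above are [lower, upper, upper * table_extrapolate, stride_1, stride_2, check_frequency]; the metadata tensor is kept on the CPU (the forward pass calls .cpu() on it before handing it to the custom op) while the table payload is moved to the compute device. The sketch below mirrors that packing with assumed numbers and a fake table:

```python
import torch

# Assumed values for illustration (normally produced by DPTabulate.build).
lower, upper = -1.0, 1.0
table_config = [5.0, 0.1, 1.0, -1]  # [extrapolate, stride_1, stride_2, check_frequency]

# Metadata stays on CPU ...
compress_info = torch.as_tensor(
    [lower, upper, upper * table_config[0], *table_config[1:]],
    dtype=torch.float64,
    device="cpu",
)
# ... while the (here fake) table payload lives on the compute device.
device = "cuda" if torch.cuda.is_available() else "cpu"
compress_data = torch.zeros(26, 24, dtype=torch.float64, device=device)
print(compress_info.tolist())  # [-1.0, 1.0, 5.0, 0.1, 1.0, -1.0]
```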

deepmd/pt/utils/tabulate.py

Lines changed: 13 additions & 9 deletions
@@ -66,12 +66,7 @@ def __init__(
         )
         self.descrpt_type = self._get_descrpt_type()

-        supported_descrpt_type = (
-            "Atten",
-            "A",
-            "T",
-            "R",
-        )
+        supported_descrpt_type = ("Atten", "A", "T", "T_TEBD", "R")

         if self.descrpt_type in supported_descrpt_type:
             self.sel_a = self.descrpt.get_sel()
@@ -156,7 +151,7 @@ def _make_data(self, xx: np.ndarray, idx: int) -> Any:
                     self.matrix["layer_" + str(layer + 1)][idx],
                     xbar,
                     self.functype,
-                ) + torch.ones((1, 1), dtype=yy.dtype)  # pylint: disable=no-explicit-device
+                ) + torch.ones((1, 1), dtype=yy.dtype, device=yy.device)
                 dy2 = unaggregated_dy2_dx_s(
                     yy - xx,
                     dy,
@@ -175,7 +170,7 @@ def _make_data(self, xx: np.ndarray, idx: int) -> Any:
                     self.matrix["layer_" + str(layer + 1)][idx],
                     xbar,
                     self.functype,
-                ) + torch.ones((1, 2), dtype=yy.dtype)  # pylint: disable=no-explicit-device
+                ) + torch.ones((1, 2), dtype=yy.dtype, device=yy.device)
                 dy2 = unaggregated_dy2_dx_s(
                     yy - tt,
                     dy,
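The two torch.ones changes above replace device-less constants (previously silenced with a pylint pragma) with constants created on yy's device, which avoids a CPU/GPU device mismatch when the table is built from CUDA tensors. A minimal illustration of the failure mode being avoided:

```python
import torch

yy = torch.zeros(1, 1, device="cuda" if torch.cuda.is_available() else "cpu")

# Before: an implicit CPU constant; on a CUDA run this add would raise
# "Expected all tensors to be on the same device".
# bad = yy + torch.ones((1, 1), dtype=yy.dtype)

# After: the constant follows yy's device, so the add works on CPU and GPU alike.
good = yy + torch.ones((1, 1), dtype=yy.dtype, device=yy.device)
print(good.device)
```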
@@ -311,6 +306,8 @@ def _get_descrpt_type(self) -> str:
             return "R"
         elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeT):
             return "T"
+        elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeTTebd):
+            return "T_TEBD"
         raise RuntimeError(f"Unsupported descriptor {self.descrpt}")

     def _get_layer_size(self) -> int:
@@ -325,7 +322,7 @@ def _get_layer_size(self) -> int:
                * len(self.embedding_net_nodes[0])
                * len(self.neuron)
            )
-        if self.descrpt_type == "Atten":
+        if self.descrpt_type in ("Atten", "T_TEBD"):
             layer_size = len(self.embedding_net_nodes[0]["layers"])
         elif self.descrpt_type == "A":
             layer_size = len(self.embedding_net_nodes[0]["layers"])
@@ -394,6 +391,13 @@ def _get_network_variable(self, var_name: str) -> dict:
                            "layers"
                        ][layer - 1]["@variables"][var_name]
                        result["layer_" + str(layer)].append(node)
+            elif self.descrpt_type == "T_TEBD":
+                # For the se_e3_tebd descriptor, a single,
+                # shared embedding network is used for all type pairs
+                node = self.embedding_net_nodes[0]["layers"][layer - 1]["@variables"][
+                    var_name
+                ]
+                result["layer_" + str(layer)].append(node)
             elif self.descrpt_type == "R":
                 if self.type_one_side:
                     for ii in range(0, self.ntypes):
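Unlike the per-type-pair branches earlier in this function, the new T_TEBD branch reads every layer from the single shared embedding network stored under key 0 of embedding_net_nodes. The sketch below shows the nested layout that this lookup assumes; the field names follow the indexing in the diff, while the surrounding structure and example values are assumptions for illustration:

```python
# Assumed serialized layout, consistent with the indexing used above.
embedding_net_nodes = {
    0: {  # one shared embedding network for all type pairs
        "layers": [
            {"@variables": {"w": [[0.1, 0.2]], "b": [0.0, 0.0]}},  # layer_1
            {"@variables": {"w": [[0.3], [0.4]], "b": [0.0]}},     # layer_2
        ]
    }
}

layer, var_name = 1, "w"
node = embedding_net_nodes[0]["layers"][layer - 1]["@variables"][var_name]
print(node)  # weights of layer 1 of the single shared network
```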
