
Commit 86cb76f

feat(pt): Implement type embedding compression for se_atten (#5057)
## Summary by CodeRabbit

* **New Features**
  * Added type-embedding compression with an option to precompute and cache embeddings when compression is enabled.
* **Behavioral Improvements**
  * Descriptor evaluation now prefers cached type embeddings to reduce repeated computation, with fallbacks to on-the-fly computation.
* **Tests**
  * Unit tests updated to verify propagation and presence of precomputed type-embedding data.
1 parent 7f6b387 commit 86cb76f

File tree: 5 files changed, +100 −30 lines

* deepmd/pt/model/descriptor/dpa1.py
* deepmd/pt/model/descriptor/dpa2.py
* deepmd/pt/model/descriptor/se_atten.py
* source/tests/pt/model/test_descriptor_dpa1.py
* source/tests/pt/model/test_descriptor_dpa2.py

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -645,6 +645,10 @@ def enable_compression(
         self.se_atten.enable_compression(
             self.table.data, self.table_config, self.lower, self.upper
         )
+
+        # Enable type embedding compression
+        self.se_atten.type_embedding_compression(self.type_embedding)
+
         self.compress = True
 
     def forward(
```
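The DPA-1 hook is a single extra call: once the geometric compression tables are installed, `enable_compression` also bakes the strip-mode type-embedding network into a lookup table. A minimal standalone sketch of the idea (illustrative names and sizes, not the DeePMD-kit API): because the filter network's input domain is the finite set of type embeddings, its outputs can be computed once and reused.

```python
import torch

# Sketch only: "filter_net" and the sizes below are illustrative, not
# the DeePMD-kit API. The point is that a network whose inputs come
# from a finite set can be replaced by a precomputed table.
ntypes_pad, tebd_dim, ng = 5, 8, 16  # (ntypes+1), embedding dim, filter width
filter_net = torch.nn.Sequential(torch.nn.Linear(tebd_dim, ng), torch.nn.Tanh())
type_embedding = torch.randn(ntypes_pad, tebd_dim)

with torch.no_grad():
    # computed once at compression time: (ntypes+1) x ng
    cache = filter_net(type_embedding)

    # at evaluation time, a per-neighbor type index replaces the MLP call
    nei_type = torch.randint(0, ntypes_pad, (1000,))
    gg_t = cache[nei_type]  # cheap row lookup
    assert torch.allclose(gg_t, filter_net(type_embedding)[nei_type])
```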

deepmd/pt/model/descriptor/dpa2.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -970,4 +970,8 @@ def enable_compression(
         self.repinit.enable_compression(
             self.table.data, self.table_config, self.lower, self.upper
         )
+
+        # Enable type embedding compression for repinit (se_atten)
+        self.repinit.type_embedding_compression(self.type_embedding)
+
         self.compress = True
```

deepmd/pt/model/descriptor/se_atten.py

Lines changed: 89 additions & 30 deletions
```diff
@@ -27,6 +27,9 @@
     MLPLayer,
     NetworkCollection,
 )
+from deepmd.pt.model.network.network import (
+    TypeEmbedNet,
+)
 from deepmd.pt.utils import (
     env,
 )
@@ -272,7 +275,7 @@ def __init__(
         self.filter_layers_strip = filter_layers_strip
         self.stats = None
 
-        # add for compression
+        # For geometric compression
         self.compress = False
         self.is_sorted = False
         self.compress_info = nn.ParameterList(
@@ -281,6 +284,10 @@ def __init__(
         self.compress_data = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
         )
+        # For type embedding compression
+        self.register_buffer(
+            "type_embd_data", torch.zeros(0, dtype=self.prec, device=env.DEVICE)
+        )
 
     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
```
```diff
@@ -447,6 +454,56 @@ def enable_compression(
         self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec)
         self.compress = True
 
+    def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
+        """Enable type embedding compression for strip mode.
+
+        Precomputes embedding network outputs for all type combinations:
+        - One-side: (ntypes+1) combinations (neighbor types only)
+        - Two-side: (ntypes+1)² combinations (neighbor x center type pairs)
+
+        Parameters
+        ----------
+        type_embedding_net : TypeEmbedNet
+            The type embedding network that provides the get_full_embedding() method.
+        """
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Type embedding compression only works in strip mode")
+        if self.filter_layers_strip is None:
+            raise RuntimeError(
+                "filter_layers_strip must be initialized for type embedding compression"
+            )
+
+        with torch.no_grad():
+            # Get the full type embedding: (ntypes+1) x tebd_dim
+            full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
+            nt, t_dim = full_embd.shape
+
+            if self.type_one_side:
+                # One-side: only neighbor types, much simpler!
+                # Precompute for all (ntypes+1) neighbor types
+                embd_tensor = self.filter_layers_strip.networks[0](full_embd).detach()
+                if hasattr(self, "type_embd_data"):
+                    del self.type_embd_data
+                self.register_buffer("type_embd_data", embd_tensor)
+            else:
+                # Two-side: all (ntypes+1)² type pair combinations.
+                # Create [neighbor, center] combinations:
+                # for a fixed row i, all columns j have different neighbor types
+                embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
+                # for a fixed row i, all columns j share the same center type i
+                embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
+                two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape(
+                    -1, t_dim * 2
+                )
+                # Precompute for all type pairs.
+                # Index formula: idx = center_type * nt + neighbor_type
+                embd_tensor = self.filter_layers_strip.networks[0](
+                    two_side_embd
+                ).detach()
+                if hasattr(self, "type_embd_data"):
+                    del self.type_embd_data
+                self.register_buffer("type_embd_data", embd_tensor)
+
     def forward(
         self,
         nlist: torch.Tensor,
```
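The two-side row layout deserves a sanity check: in the pair table built above, row i holds center type i and column j holds neighbor type j, so after the reshape the flat row index is `center_type * nt + neighbor_type`, exactly the `idx_i + idx_j` computed in `forward` below. A standalone check with illustrative sizes:

```python
import torch

# Standalone check of the pair-table layout (illustrative sizes): the
# flat row index center * nt + neighbor must recover the row whose
# first half is the neighbor embedding and second half the center one.
nt, t_dim = 4, 3
full_embd = torch.arange(nt * t_dim, dtype=torch.float32).view(nt, t_dim)

embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)     # [i, j] -> type j
embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)  # [i, j] -> type i
pairs = torch.cat([embd_nei, embd_center], dim=-1).reshape(-1, 2 * t_dim)

center, neighbor = 2, 1
row = pairs[center * nt + neighbor]
assert torch.equal(row[:t_dim], full_embd[neighbor])
assert torch.equal(row[t_dim:], full_embd[center])
```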
```diff
@@ -572,42 +629,44 @@ def forward(
             nlist_index = nlist.reshape(nb, nloc * nnei)
             # nf x (nl x nnei)
             nei_type = torch.gather(extended_atype, dim=1, index=nlist_index)
-            # (nf x nl x nnei) x ng
-            nei_type_index = nei_type.view(-1, 1).expand(-1, ng).type(torch.long)
             if self.type_one_side:
-                tt_full = self.filter_layers_strip.networks[0](type_embedding)
-                # (nf x nl x nnei) x ng
-                gg_t = torch.gather(tt_full, dim=0, index=nei_type_index)
+                if self.compress:
+                    tt_full = self.type_embd_data
+                else:
+                    # (ntypes+1, tebd_dim) -> (ntypes+1, ng)
+                    tt_full = self.filter_layers_strip.networks[0](type_embedding)
+                # (nf*nl*nnei,) -> (nf*nl*nnei, ng)
+                gg_t = tt_full[nei_type.view(-1).type(torch.long)]
             else:
                 idx_i = torch.tile(
                     atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei]
                 ).view(-1)
                 idx_j = nei_type.view(-1)
+                # (nf x nl x nnei)
+                idx = (idx_i + idx_j).to(torch.long)
+                if self.compress:
+                    # ((ntypes+1)^2, ng)
+                    tt_full = self.type_embd_data
+                else:
+                    # (ntypes+1) x (ntypes+1) x nt
+                    type_embedding_nei = torch.tile(
+                        type_embedding.view(1, ntypes_with_padding, nt),
+                        [ntypes_with_padding, 1, 1],
+                    )
+                    # (ntypes+1) x (ntypes+1) x nt
+                    type_embedding_center = torch.tile(
+                        type_embedding.view(ntypes_with_padding, 1, nt),
+                        [1, ntypes_with_padding, 1],
+                    )
+                    # ((ntypes+1)^2) x (nt+nt)
+                    two_side_type_embedding = torch.cat(
+                        [type_embedding_nei, type_embedding_center], -1
+                    ).reshape(-1, nt * 2)
+                    tt_full = self.filter_layers_strip.networks[0](
+                        two_side_type_embedding
+                    )
                 # (nf x nl x nnei) x ng
-                idx = (
-                    (idx_i + idx_j)
-                    .view(-1, 1)
-                    .expand(-1, ng)
-                    .type(torch.long)
-                    .to(torch.long)
-                )
-                # (ntypes) * ntypes * nt
-                type_embedding_nei = torch.tile(
-                    type_embedding.view(1, ntypes_with_padding, nt),
-                    [ntypes_with_padding, 1, 1],
-                )
-                # ntypes * (ntypes) * nt
-                type_embedding_center = torch.tile(
-                    type_embedding.view(ntypes_with_padding, 1, nt),
-                    [1, ntypes_with_padding, 1],
-                )
-                # (ntypes * ntypes) * (nt+nt)
-                two_side_type_embedding = torch.cat(
-                    [type_embedding_nei, type_embedding_center], -1
-                ).reshape(-1, nt * 2)
-                tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
-                # (nf x nl x nnei) x ng
-                gg_t = torch.gather(tt_full, dim=0, index=idx)
+                gg_t = tt_full[idx]
             # (nf x nl) x nnei x ng
             gg_t = gg_t.reshape(nfnl, nnei, ng)
             if self.smooth:
```
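Beyond consulting the cache, the rewrite also simplifies the lookup itself: indexing a 2-D table with a 1-D long tensor selects whole rows, so `tt_full[idx]` does what the old code achieved by expanding the index to `(N, ng)` and calling `torch.gather`. A standalone equivalence check with illustrative sizes:

```python
import torch

# Standalone check (illustrative sizes): row selection by advanced
# indexing equals the old expanded torch.gather pattern.
n_rows, ng, n_lookup = 25, 16, 1000
tt_full = torch.randn(n_rows, ng)
idx = torch.randint(0, n_rows, (n_lookup,))

old = torch.gather(tt_full, dim=0, index=idx.view(-1, 1).expand(-1, ng))
new = tt_full[idx]
assert torch.equal(old, new)
```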

source/tests/pt/model/test_descriptor_dpa1.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -249,6 +249,7 @@ def test_descriptor_block(self) -> None:
         # this is an old state dict, modify manually
         state_dict["compress_info.0"] = des.compress_info[0]
         state_dict["compress_data.0"] = des.compress_data[0]
+        state_dict["type_embd_data"] = des.type_embd_data
         des.load_state_dict(state_dict)
         coord = self.coord
         atype = self.atype
@@ -377,5 +378,6 @@ def translate_se_atten_and_type_embd_dicts_to_dpa1(
         target_dict[tk] = type_embd_dict[kk]
     record[all_keys.index("se_atten.compress_data.0")] = True
     record[all_keys.index("se_atten.compress_info.0")] = True
+    record[all_keys.index("se_atten.type_embd_data")] = True
     assert all(record)
     return target_dict
```

source/tests/pt/model/test_descriptor_dpa2.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -196,5 +196,6 @@ def translate_type_embd_dicts_to_dpa2(
         target_dict[tk] = type_embd_dict[kk]
     record[all_keys.index("repinit.compress_data.0")] = True
     record[all_keys.index("repinit.compress_info.0")] = True
+    record[all_keys.index("repinit.type_embd_data")] = True
     assert all(record)
     return target_dict
```
