4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
@@ -645,6 +645,10 @@ def enable_compression(
self.se_atten.enable_compression(
self.table.data, self.table_config, self.lower, self.upper
)

# Enable type embedding compression
self.se_atten.type_embedding_compression(self.type_embedding)

self.compress = True

def forward(
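A minimal usage sketch of the entry point this hunk extends, assuming a DescrptDPA1 built with tebd_input_mode="strip"; descriptor and min_nbor_dist are hypothetical names, not taken from this diff:

# Hypothetical: descriptor is a trained DescrptDPA1 in "strip" mode.
min_nbor_dist = 0.5  # e.g. the minimal neighbor distance in the training data
descriptor.enable_compression(min_nbor_dist)
# With this change, enable_compression also precomputes the type-embedding
# network outputs (via type_embedding_compression), so forward() can skip
# both the radial and the type-embedding MLP evaluations at inference time.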
113 changes: 83 additions & 30 deletions deepmd/pt/model/descriptor/se_atten.py
@@ -27,6 +27,9 @@
MLPLayer,
NetworkCollection,
)
from deepmd.pt.model.network.network import (
TypeEmbedNet,
)
from deepmd.pt.utils import (
env,
)
@@ -272,7 +275,7 @@ def __init__(
self.filter_layers_strip = filter_layers_strip
self.stats = None

# add for compression
# For geometric compression
self.compress = False
self.is_sorted = False
self.compress_info = nn.ParameterList(
@@ -281,6 +284,8 @@
self.compress_data = nn.ParameterList(
[nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
)
# For type embedding compression
self.type_embd_data = None

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
@@ -447,6 +452,52 @@ def enable_compression(
self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec)
self.compress = True

def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
"""Enable type embedding compression for strip mode.

Precomputes embedding network outputs for all type combinations:
- One-side: (ntypes+1) combinations (neighbor types only)
- Two-side: (ntypes+1)² combinations (neighbor x center type pairs)

Parameters
----------
type_embedding_net : TypeEmbedNet
The type embedding network that provides the get_full_embedding() method.
"""
if self.tebd_input_mode != "strip":
raise RuntimeError("Type embedding compression only works in strip mode")
if self.filter_layers_strip is None:
raise RuntimeError(
"filter_layers_strip must be initialized for type embedding compression"
)

with torch.no_grad():
# Get full type embedding: (ntypes+1) x tebd_dim
full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
nt, t_dim = full_embd.shape

if self.type_one_side:
# One-side: the embedding depends on the neighbor type only
# Precompute for all (ntypes+1) neighbor types
self.type_embd_data = self.filter_layers_strip.networks[0](
full_embd
).detach()
else:
# Two-side: all (ntypes+1)² type pair combinations
# Create [neighbor, center] combinations
# for a fixed row i, all columns j have different neighbor types
embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
# for a fixed row i, all columns j share the same center type i
embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape(
-1, t_dim * 2
)
# Precompute for all type pairs
# Index formula: idx = center_type * nt + neighbor_type
self.type_embd_data = self.filter_layers_strip.networks[0](
two_side_embd
).detach()
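
# --- Reviewer sketch (not part of this diff): numeric check of the two-side
# layout built above. With ntypes = 2 the table has nt = ntypes + 1 = 3 rows
# (the extra row is the padding type); row center_type * nt + neighbor_type
# holds [neighbor_embedding, center_embedding], per the index formula above.
import torch
nt, t_dim = 3, 4
full_embd = torch.arange(nt * t_dim, dtype=torch.float64).view(nt, t_dim)
embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
pairs = torch.cat([embd_nei, embd_center], dim=-1).reshape(-1, 2 * t_dim)
row = pairs[2 * nt + 1]  # center_type = 2, neighbor_type = 1
assert torch.equal(row[:t_dim], full_embd[1])  # first half: neighbor embedding
assert torch.equal(row[t_dim:], full_embd[2])  # second half: center embedding
# --- end sketch ---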

def forward(
self,
nlist: torch.Tensor,
@@ -572,42 +623,44 @@ def forward(
nlist_index = nlist.reshape(nb, nloc * nnei)
# nf x (nl x nnei)
nei_type = torch.gather(extended_atype, dim=1, index=nlist_index)
# (nf x nl x nnei) x ng
nei_type_index = nei_type.view(-1, 1).expand(-1, ng).type(torch.long)
if self.type_one_side:
tt_full = self.filter_layers_strip.networks[0](type_embedding)
# (nf x nl x nnei) x ng
gg_t = torch.gather(tt_full, dim=0, index=nei_type_index)
if self.type_embd_data is not None:
tt_full = self.type_embd_data
else:
# (ntypes+1, tebd_dim) -> (ntypes+1, ng)
tt_full = self.filter_layers_strip.networks[0](type_embedding)
# (nf*nl*nnei,) -> (nf*nl*nnei, ng)
gg_t = tt_full[nei_type.view(-1).type(torch.long)]
else:
idx_i = torch.tile(
atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei]
).view(-1)
idx_j = nei_type.view(-1)
# (nf x nl x nnei)
idx = (idx_i + idx_j).to(torch.long)
if self.type_embd_data is not None:
# ((ntypes+1)^2, ng)
tt_full = self.type_embd_data
else:
# (ntypes+1) * (ntypes+1) * nt; tiled along the center (first) dim
type_embedding_nei = torch.tile(
type_embedding.view(1, ntypes_with_padding, nt),
[ntypes_with_padding, 1, 1],
)
# (ntypes+1) * (ntypes+1) * nt; tiled along the neighbor (second) dim
type_embedding_center = torch.tile(
type_embedding.view(ntypes_with_padding, 1, nt),
[1, ntypes_with_padding, 1],
)
# ((ntypes+1) * (ntypes+1)) * (nt+nt)
two_side_type_embedding = torch.cat(
[type_embedding_nei, type_embedding_center], -1
).reshape(-1, nt * 2)
tt_full = self.filter_layers_strip.networks[0](
two_side_type_embedding
)
# (nf x nl x nnei) x ng
idx = (
(idx_i + idx_j)
.view(-1, 1)
.expand(-1, ng)
.type(torch.long)
.to(torch.long)
)
# (ntypes) * ntypes * nt
type_embedding_nei = torch.tile(
type_embedding.view(1, ntypes_with_padding, nt),
[ntypes_with_padding, 1, 1],
)
# ntypes * (ntypes) * nt
type_embedding_center = torch.tile(
type_embedding.view(ntypes_with_padding, 1, nt),
[1, ntypes_with_padding, 1],
)
# (ntypes * ntypes) * (nt+nt)
two_side_type_embedding = torch.cat(
[type_embedding_nei, type_embedding_center], -1
).reshape(-1, nt * 2)
tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
# (nf x nl x nnei) x ng
gg_t = torch.gather(tt_full, dim=0, index=idx)
gg_t = tt_full[idx]
# (nf x nl) x nnei x ng
gg_t = gg_t.reshape(nfnl, nnei, ng)
if self.smooth:
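
The forward() changes also replace the expand-then-gather lookup with plain row indexing into the precomputed table. A minimal standalone sketch of the equivalence, with illustrative shapes (tt_full stands in for the precomputed table):

import torch
n_rows, ng, n_sel = 9, 5, 7  # e.g. (ntypes+1)^2 table rows, ng outputs
tt_full = torch.randn(n_rows, ng)  # stand-in for the precomputed table
idx = torch.randint(0, n_rows, (n_sel,))  # flattened (center, neighbor) indices
new_way = tt_full[idx]  # a 1-D long index selects whole rows: (n_sel, ng)
old_way = torch.gather(tt_full, 0, idx.view(-1, 1).expand(-1, ng))
assert torch.equal(new_way, old_way)  # same result, smaller index tensor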