     MLPLayer,
     NetworkCollection,
 )
+from deepmd.pt.model.network.network import (
+    TypeEmbedNet,
+)
 from deepmd.pt.utils import (
     env,
 )
@@ -272,7 +275,7 @@ def __init__(
         self.filter_layers_strip = filter_layers_strip
         self.stats = None

-        # add for compression
+        # For geometric compression
         self.compress = False
         self.is_sorted = False
         self.compress_info = nn.ParameterList(
@@ -281,6 +284,10 @@ def __init__(
         self.compress_data = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
         )
+        # For type embedding compression
+        self.register_buffer(
+            "type_embd_data", torch.zeros(0, dtype=self.prec, device=env.DEVICE)
+        )
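+        # Zero-size placeholder; type_embedding_compression() re-registers
+        # this buffer with the precomputed lookup table.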

     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -447,6 +454,56 @@ def enable_compression(
         self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec)
         self.compress = True

+    def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
+        """Enable type embedding compression for strip mode.
+
+        Precomputes the embedding network outputs for all type combinations:
+        - One-side: (ntypes+1) combinations (neighbor types only)
+        - Two-side: (ntypes+1)² combinations (neighbor x center type pairs)
+
+        Parameters
+        ----------
+        type_embedding_net : TypeEmbedNet
+            The type embedding network, which provides the get_full_embedding() method.
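+
+        Notes
+        -----
+        The two-side table is laid out so that row
+        idx = center_type * (ntypes + 1) + neighbor_type holds the pair
+        (neighbor_type, center_type), matching the index built in forward().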
+        """
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Type embedding compression only works in strip mode")
+        if self.filter_layers_strip is None:
+            raise RuntimeError(
+                "filter_layers_strip must be initialized for type embedding compression"
+            )
+
+        with torch.no_grad():
+            # Get the full type embedding: (ntypes+1) x tebd_dim
+            full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
+            nt, t_dim = full_embd.shape
+
+            if self.type_one_side:
+                # One-side: precompute for all (ntypes+1) neighbor types
+                embd_tensor = self.filter_layers_strip.networks[0](full_embd).detach()
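+                # register_buffer() raises if the attribute already exists,
+                # so drop the zero-size placeholder first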
+                if hasattr(self, "type_embd_data"):
+                    del self.type_embd_data
+                self.register_buffer("type_embd_data", embd_tensor)
+            else:
+                # Two-side: all (ntypes+1)² [neighbor, center] type pair combinations
+                # For a fixed row i, all columns j have different neighbor types
+                embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
+                # For a fixed row i, all columns j share the same center type i
+                embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
+                two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape(
+                    -1, t_dim * 2
+                )
+                # Precompute for all type pairs
+                # Index formula: idx = center_type * nt + neighbor_type
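+                # e.g. ntypes = 2 gives nt = 3 (two real types plus padding),
+                # so the pair (center=1, neighbor=2) maps to row 1 * 3 + 2 = 5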
+                embd_tensor = self.filter_layers_strip.networks[0](
+                    two_side_embd
+                ).detach()
+                if hasattr(self, "type_embd_data"):
+                    del self.type_embd_data
+                self.register_buffer("type_embd_data", embd_tensor)
+
     def forward(
         self,
         nlist: torch.Tensor,
@@ -572,42 +629,44 @@ def forward(
             nlist_index = nlist.reshape(nb, nloc * nnei)
             # nf x (nl x nnei)
             nei_type = torch.gather(extended_atype, dim=1, index=nlist_index)
-            # (nf x nl x nnei) x ng
-            nei_type_index = nei_type.view(-1, 1).expand(-1, ng).type(torch.long)
             if self.type_one_side:
-                tt_full = self.filter_layers_strip.networks[0](type_embedding)
-                # (nf x nl x nnei) x ng
-                gg_t = torch.gather(tt_full, dim=0, index=nei_type_index)
+                if self.compress:
+                    tt_full = self.type_embd_data
+                else:
+                    # (ntypes+1, tebd_dim) -> (ntypes+1, ng)
+                    tt_full = self.filter_layers_strip.networks[0](type_embedding)
+                # (nf*nl*nnei,) -> (nf*nl*nnei, ng)
+                gg_t = tt_full[nei_type.view(-1).type(torch.long)]
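+                # Plain row indexing replaces the previous torch.gather call,
+                # so no (nf x nl x nnei) x ng index tensor is materialized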
             else:
                 idx_i = torch.tile(
                     atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei]
                 ).view(-1)
                 idx_j = nei_type.view(-1)
+                # (nf x nl x nnei)
+                idx = (idx_i + idx_j).to(torch.long)
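+                # flat row index into the (ntypes+1)^2 table:
+                # center_type * ntypes_with_padding + neighbor_type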
+                if self.compress:
+                    # ((ntypes+1)^2, ng)
+                    tt_full = self.type_embd_data
+                else:
+                    # (ntypes+1) x (ntypes+1) x nt, neighbor type varies along dim 1
+                    type_embedding_nei = torch.tile(
+                        type_embedding.view(1, ntypes_with_padding, nt),
+                        [ntypes_with_padding, 1, 1],
+                    )
+                    # (ntypes+1) x (ntypes+1) x nt, center type varies along dim 0
+                    type_embedding_center = torch.tile(
+                        type_embedding.view(ntypes_with_padding, 1, nt),
+                        [1, ntypes_with_padding, 1],
+                    )
+                    # ((ntypes+1)^2) x (nt + nt)
+                    two_side_type_embedding = torch.cat(
+                        [type_embedding_nei, type_embedding_center], -1
+                    ).reshape(-1, nt * 2)
+                    tt_full = self.filter_layers_strip.networks[0](
+                        two_side_type_embedding
+                    )
                 # (nf x nl x nnei) x ng
-                idx = (
-                    (idx_i + idx_j)
-                    .view(-1, 1)
-                    .expand(-1, ng)
-                    .type(torch.long)
-                    .to(torch.long)
-                )
-                # (ntypes) * ntypes * nt
-                type_embedding_nei = torch.tile(
-                    type_embedding.view(1, ntypes_with_padding, nt),
-                    [ntypes_with_padding, 1, 1],
-                )
-                # ntypes * (ntypes) * nt
-                type_embedding_center = torch.tile(
-                    type_embedding.view(ntypes_with_padding, 1, nt),
-                    [1, ntypes_with_padding, 1],
-                )
-                # (ntypes * ntypes) * (nt+nt)
-                two_side_type_embedding = torch.cat(
-                    [type_embedding_nei, type_embedding_center], -1
-                ).reshape(-1, nt * 2)
-                tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
-                # (nf x nl x nnei) x ng
-                gg_t = torch.gather(tt_full, dim=0, index=idx)
+                gg_t = tt_full[idx]
             # (nf x nl) x nnei x ng
             gg_t = gg_t.reshape(nfnl, nnei, ng)
             if self.smooth:
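
Since forward() gates the table lookup on self.compress, both compression steps
must run before inference. A minimal usage sketch (desc and te are placeholder
names for a strip-mode descriptor block and its TypeEmbedNet; not part of this
diff):

    desc.enable_compression(...)         # builds the geometric tables, sets desc.compress
    desc.type_embedding_compression(te)  # precomputes desc.type_embd_data
    # forward() now reads gg_t from desc.type_embd_data instead of
    # evaluating filter_layers_strip.networks[0] on every call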