@@ -102,10 +102,13 @@ def __init__(
102102 f"Elements { missing_elements } are not present in the provided `type_map`."
103103 )
104104 if not self .mixed_type :
105- atom_type_ = [
106- type_map .index (self .type_map [ii ]) for ii in self .atom_type
107- ]
108- self .atom_type = np .array (atom_type_ , dtype = np .int32 )
105+ # Use vectorized operation for better performance with large atom counts
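106+ # Create a mapping array where old_type_idx -> new_type_idx. NOTE(review): the array is sized by max(self.atom_type) + 1 but filled for every index of self.type_map, so a type listed in self.type_map yet absent from atom_type indexes past the end (IndexError); size it by len(self.type_map) or guard the loop.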
107+ max_old_type = max (self .atom_type ) + 1
108+ type_mapping = np .zeros (max_old_type , dtype = np .int32 )
109+ for old_idx in range (len (self .type_map )):
110+ type_mapping [old_idx ] = type_map .index (self .type_map [old_idx ])
111+ self .atom_type = type_mapping [self .atom_type ].astype (np .int32 )
109112 else :
110113 self .enforce_type_map = True
111114 sorter = np .argsort (type_map )
@@ -489,11 +492,9 @@ def avg(self, key: str) -> float:
489492 return np .average (eners , axis = 0 )
490493
491494 def _idx_map_sel (self , atom_type : np .ndarray , type_sel : list [int ]) -> np .ndarray :
492- new_types = []
493- for ii in atom_type :
494- if ii in type_sel :
495- new_types .append (ii )
496- new_types = np .array (new_types , dtype = int )
495+ # Use vectorized operations instead of Python loop
496+ sel_mask = np .isin (atom_type , type_sel )
497+ new_types = atom_type [sel_mask ]
497498 natoms = new_types .shape [0 ]
498499 idx = np .arange (natoms , dtype = np .int64 )
499500 idx_map = np .lexsort ((idx , new_types ))
@@ -717,9 +718,9 @@ def _load_data(
717718 idx_map = self .idx_map
718719 # if type_sel, then revise natoms and idx_map
719720 if type_sel is not None :
720- natoms_sel = 0
721- for jj in type_sel :
722- natoms_sel + = np .sum (self . atom_type == jj )
721+ # Use vectorized operations for better performance
722+ sel_mask = np . isin ( self . atom_type , type_sel )
723+ natoms_sel = np .sum (sel_mask )
723724 idx_map_sel = self ._idx_map_sel (self .atom_type , type_sel )
724725 else :
725726 natoms_sel = natoms
@@ -747,7 +748,6 @@ def _load_data(
747748 tmp = np .zeros (
748749 [nframes , natoms , ndof_ ], dtype = data .dtype
749750 )
750- sel_mask = np .isin (self .atom_type , type_sel )
751751 tmp [:, sel_mask ] = data .reshape (
752752 [nframes , natoms_sel , ndof_ ]
753753 )
@@ -760,7 +760,6 @@ def _load_data(
760760 if output_natoms_for_type_sel :
761761 pass
762762 else :
763- sel_mask = np .isin (self .atom_type , type_sel )
764763 data = data .reshape ([nframes , natoms , ndof_ ])
765764 data = data [:, sel_mask ]
766765 natoms = natoms_sel
@@ -836,9 +835,9 @@ def _load_single_data(
836835 idx_map = self .idx_map
837836 # if type_sel, then revise natoms and idx_map
838837 if vv ["type_sel" ] is not None :
839- natoms_sel = 0
840- for jj in vv ["type_sel" ]:
841- natoms_sel + = np .sum (self . atom_type == jj )
838+ # Use vectorized operations for better performance
839+ sel_mask = np . isin ( self . atom_type , vv ["type_sel" ])
840+ natoms_sel = np .sum (sel_mask )
842841 idx_map_sel = self ._idx_map_sel (self .atom_type , vv ["type_sel" ])
843842 else :
844843 natoms_sel = natoms
@@ -885,8 +884,6 @@ def _load_single_data(
885884 if vv ["atomic" ]:
886885 # Handle type_sel logic
887886 if vv ["type_sel" ] is not None :
888- sel_mask = np .isin (self .atom_type , vv ["type_sel" ])
889-
890887 if mmap_obj .shape [1 ] == natoms_sel * ndof :
891888 if vv ["output_natoms_for_type_sel" ]:
892889 tmp = np .zeros ([natoms , ndof ], dtype = data .dtype )
0 commit comments