Skip to content

Commit 40e7a12

Browse files
committed
perf: vectorize type selection
1 parent e6d4638 commit 40e7a12

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

deepmd/utils/data.py

Lines changed: 16 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -102,10 +102,13 @@ def __init__(
102102
f"Elements {missing_elements} are not present in the provided `type_map`."
103103
)
104104
if not self.mixed_type:
105-
atom_type_ = [
106-
type_map.index(self.type_map[ii]) for ii in self.atom_type
107-
]
108-
self.atom_type = np.array(atom_type_, dtype=np.int32)
105+
# Use vectorized operation for better performance with large atom counts
106+
# Create a mapping array where old_type_idx -> new_type_idx
107+
max_old_type = max(self.atom_type) + 1
108+
type_mapping = np.zeros(max_old_type, dtype=np.int32)
109+
for old_idx in range(len(self.type_map)):
110+
type_mapping[old_idx] = type_map.index(self.type_map[old_idx])
111+
self.atom_type = type_mapping[self.atom_type].astype(np.int32)
109112
else:
110113
self.enforce_type_map = True
111114
sorter = np.argsort(type_map)
@@ -489,11 +492,9 @@ def avg(self, key: str) -> float:
489492
return np.average(eners, axis=0)
490493

491494
def _idx_map_sel(self, atom_type: np.ndarray, type_sel: list[int]) -> np.ndarray:
492-
new_types = []
493-
for ii in atom_type:
494-
if ii in type_sel:
495-
new_types.append(ii)
496-
new_types = np.array(new_types, dtype=int)
495+
# Use vectorized operations instead of Python loop
496+
sel_mask = np.isin(atom_type, type_sel)
497+
new_types = atom_type[sel_mask]
497498
natoms = new_types.shape[0]
498499
idx = np.arange(natoms, dtype=np.int64)
499500
idx_map = np.lexsort((idx, new_types))
@@ -717,9 +718,9 @@ def _load_data(
717718
idx_map = self.idx_map
718719
# if type_sel, then revise natoms and idx_map
719720
if type_sel is not None:
720-
natoms_sel = 0
721-
for jj in type_sel:
722-
natoms_sel += np.sum(self.atom_type == jj)
721+
# Use vectorized operations for better performance
722+
sel_mask = np.isin(self.atom_type, type_sel)
723+
natoms_sel = np.sum(sel_mask)
723724
idx_map_sel = self._idx_map_sel(self.atom_type, type_sel)
724725
else:
725726
natoms_sel = natoms
@@ -747,7 +748,6 @@ def _load_data(
747748
tmp = np.zeros(
748749
[nframes, natoms, ndof_], dtype=data.dtype
749750
)
750-
sel_mask = np.isin(self.atom_type, type_sel)
751751
tmp[:, sel_mask] = data.reshape(
752752
[nframes, natoms_sel, ndof_]
753753
)
@@ -760,7 +760,6 @@ def _load_data(
760760
if output_natoms_for_type_sel:
761761
pass
762762
else:
763-
sel_mask = np.isin(self.atom_type, type_sel)
764763
data = data.reshape([nframes, natoms, ndof_])
765764
data = data[:, sel_mask]
766765
natoms = natoms_sel
@@ -836,9 +835,9 @@ def _load_single_data(
836835
idx_map = self.idx_map
837836
# if type_sel, then revise natoms and idx_map
838837
if vv["type_sel"] is not None:
839-
natoms_sel = 0
840-
for jj in vv["type_sel"]:
841-
natoms_sel += np.sum(self.atom_type == jj)
838+
# Use vectorized operations for better performance
839+
sel_mask = np.isin(self.atom_type, vv["type_sel"])
840+
natoms_sel = np.sum(sel_mask)
842841
idx_map_sel = self._idx_map_sel(self.atom_type, vv["type_sel"])
843842
else:
844843
natoms_sel = natoms
@@ -885,8 +884,6 @@ def _load_single_data(
885884
if vv["atomic"]:
886885
# Handle type_sel logic
887886
if vv["type_sel"] is not None:
888-
sel_mask = np.isin(self.atom_type, vv["type_sel"])
889-
890887
if mmap_obj.shape[1] == natoms_sel * ndof:
891888
if vv["output_natoms_for_type_sel"]:
892889
tmp = np.zeros([natoms, ndof], dtype=data.dtype)

0 commit comments

Comments (0)