From bd855c5342f5006465ded463b540a696a2dc9b14 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 24 Oct 2025 15:53:06 +0800 Subject: [PATCH 01/17] perf: realize memmap to accelarate dataloader --- deepmd/pt/utils/dataset.py | 2 +- deepmd/utils/data.py | 213 +++++++++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 2cbe47cc3e..cf7b0b8e48 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -37,7 +37,7 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> dict[str, Any]: """Get a frame from the selected system.""" - b_data = self._data_system.get_item_torch(index) + b_data = self._data_system._get_single_frame(index) b_data["natoms"] = self._natoms_vec return b_data diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 9b93c64507..2d9a88d95e 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -3,6 +3,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import bisect import logging +from pathlib import ( + Path, +) from typing import ( Any, Optional, @@ -128,6 +131,8 @@ def __init__( self.shuffle_test = shuffle_test # set modifier self.modifier = modifier + # Add a cache for memory-mapped files + self.memmap_cache = {} # calculate prefix sum for get_item method frames_list = [self._get_nframes(item) for item in self.dirs] self.nframes = np.sum(frames_list) @@ -394,6 +399,214 @@ def avg(self, key: str) -> float: else: return np.average(eners, axis=0) + def _get_memmap(self, path: DPPath) -> np.memmap: + """Get or create a memory-mapped object for a given npy file.""" + memmap_key = Path(str(path)).absolute() + if memmap_key not in self.memmap_cache: + # Open the npy file to read its header and get shape/dtype + with open(str(path), "rb") as f: + version = np.lib.format.read_magic(f) + shape, fortran_order, dtype = np.lib.format._read_array_header( + f, version + ) + offset = f.tell() + order = "F" if fortran_order 
else "C" + # Create a read-only memmap and cache it + self.memmap_cache[memmap_key] = np.memmap( + str(path), + dtype=dtype, + mode="r", + shape=shape, + order=order, + offset=offset, + ) + return self.memmap_cache[memmap_key] + + def _get_single_frame(self, index: int) -> dict: + """Orchestrates loading a single frame efficiently using memmap.""" + # 1. Find the correct set directory and local frame index + set_idx = bisect.bisect_right(self.prefix_sum, index) + set_dir = self.dirs[set_idx] + if not isinstance(set_dir, DPPath): + set_dir = DPPath(set_dir) + # Calculate local index within the set.* directory + local_idx = index - self.prefix_sum[set_idx] + + frame_data = {} + # 2. Load all non-reduced items first + # TODO: use async + for key, vv in self.data_dict.items(): + if vv["reduce"] is None: + frame_data["find_" + key], frame_data[key] = self._load_single_item( + set_dir, key, local_idx + ) + + # 3. Compute reduced items from already loaded data + # TODO: use async + for key, vv in self.data_dict.items(): + if vv["reduce"] is not None: + k_in = vv["reduce"] + ndof = vv["ndof"] + frame_data["find_" + key] = frame_data["find_" + k_in] + # Reshape to (natoms, ndof) and sum over atom axis + tmp_in = ( + frame_data[k_in] + .reshape(-1, ndof) + .astype(GLOBAL_ENER_FLOAT_PRECISION) + ) + frame_data[key] = np.sum(tmp_in, axis=0) + + # 4. 
Handle atom types (mixed or standard) + # TODO: mixed_type + if self.mixed_type: + type_path = set_dir / "real_atom_types.npy" + mmap_types = self._get_memmap(type_path) + real_type = mmap_types[local_idx].copy().astype(np.int32) + + if self.enforce_type_map: + real_type = self.type_idx_map[real_type].astype(np.int32) + + frame_data["type"] = real_type + ntypes = self.get_ntypes() + natoms = len(real_type) + # Use bincount for efficient counting of each type + natoms_vec = np.bincount( + real_type[real_type >= 0], minlength=ntypes + ).astype(np.int32) + frame_data["real_natoms_vec"] = np.concatenate( + (np.array([natoms, natoms], dtype=np.int32), natoms_vec) + ) + else: + frame_data["type"] = self.atom_type[self.idx_map] + + # 5. Standardize keys + frame_data = {kk.replace("atomic", "atom"): vv for kk, vv in frame_data.items()} + + # 6. Reshape atomic data to match expected format [natoms, ndof] + for kk in self.data_dict.keys(): + if "find_" in kk: + pass + else: + if kk in frame_data and not self.data_dict[kk]["atomic"]: + frame_data[kk] = frame_data[kk].reshape(-1) + frame_data["atype"] = frame_data["type"] + + if not self.pbc: + frame_data["box"] = None + + frame_data["fid"] = index + return frame_data + + def _load_single_item( + self, set_dir: DPPath, key: str, frame_idx: int + ) -> tuple[np.float32, np.ndarray]: + """ + Loads and processes data for a SINGLE frame from a SINGLE key, + fully replicating the logic from the original _load_data method. 
+ """ + vv = self.data_dict[key] + path = set_dir / (key + ".npy") + + if vv["atomic"]: + natoms = self.natoms + idx_map = self.idx_map + # if type_sel, then revise natoms and idx_map + if vv["type_sel"] is not None: + natoms_sel = 0 + for jj in vv["type_sel"]: + natoms_sel += np.sum(self.atom_type == jj) + idx_map_sel = self._idx_map_sel(self.atom_type, vv["type_sel"]) + else: + natoms_sel = natoms + idx_map_sel = idx_map + else: + natoms = 1 + natoms_sel = 0 + idx_map_sel = None + ndof = vv["ndof"] + + # Determine target data type from requirements + dtype = vv.get("dtype") + if dtype is None: + dtype = ( + GLOBAL_ENER_FLOAT_PRECISION + if vv.get("high_prec") + else GLOBAL_NP_FLOAT_PRECISION + ) + + # Branch 1: File does not exist + if not path.is_file(): + if vv.get("must"): + raise RuntimeError(f"{path} not found!") + + # Create a default array based on requirements + if ( + vv["atomic"] + and vv["type_sel"] is not None + and not vv["output_natoms_for_type_sel"] + ): + natoms = natoms_sel + data = np.full([natoms, ndof], vv["default"], dtype=dtype) + return np.float32(0.0), data + + # Branch 2: File exists, use memmap + mmap_obj = self._get_memmap(path) + # Slice the single frame and make an in-memory copy for modification + data = mmap_obj[frame_idx].copy().astype(dtype, copy=False) + + try: + if vv["atomic"]: + # Handle type_sel logic + if vv["type_sel"] is not None: + sel_mask = np.isin(self.atom_type, vv["type_sel"]) + + if mmap_obj.shape[1] == natoms_sel * ndof: + if vv["output_natoms_for_type_sel"]: + tmp = np.zeros([natoms, ndof], dtype=data.dtype) + # sel_mask needs to be applied to the original atom layout + tmp[sel_mask] = data.reshape([natoms_sel, ndof]) + data = tmp + else: # output is natoms_sel + natoms = natoms_sel + idx_map = idx_map_sel + elif mmap_obj.shape[1] == natoms * ndof: + data = data.reshape([natoms, ndof]) + if vv["output_natoms_for_type_sel"]: + pass + else: + data = data[sel_mask] + idx_map = idx_map_sel + natoms = natoms_sel + 
else: # Shape mismatch error + raise ValueError( + f"The shape of the data {key} in {set_dir} has width {mmap_obj.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})" + ) + + # Handle special case for Hessian + if key == "hessian": + data = data.reshape(3 * natoms, 3 * natoms) + num_chunks, chunk_size = len(idx_map), 3 + idx_map_hess = np.arange( + num_chunks * chunk_size, dtype=int + ).reshape(num_chunks, chunk_size) + idx_map_hess = idx_map_hess[idx_map].flatten() + data = data[idx_map_hess, :] + data = data[:, idx_map_hess] + data = data.reshape(-1) + ndof = 3 * ndof * 3 * ndof # size of hessian is 3Natoms * 3Natoms + else: + # data should be 2D here: (natoms, ndof) + data = data.reshape([natoms, -1]) + data = data[idx_map, :] + + return np.float32(1.0), data + + except ValueError as err_message: + explanation = "This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." + log.error(str(err_message)) + log.error(explanation) + raise ValueError(str(err_message) + ". 
" + explanation) from err_message + def _idx_map_sel(self, atom_type: np.ndarray, type_sel: list[int]) -> np.ndarray: new_types = [] for ii in atom_type: From 9537ef62caf9e24cf9c22080ec8512946aa35055 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 20 Oct 2025 10:46:51 +0800 Subject: [PATCH 02/17] fix: mix type --- deepmd/pd/utils/dataset.py | 2 +- deepmd/pt/utils/dataset.py | 2 +- deepmd/utils/data.py | 323 ++++++++++++++++++------------------- 3 files changed, 161 insertions(+), 166 deletions(-) diff --git a/deepmd/pd/utils/dataset.py b/deepmd/pd/utils/dataset.py index 1f0533d8fc..e2885e340e 100644 --- a/deepmd/pd/utils/dataset.py +++ b/deepmd/pd/utils/dataset.py @@ -36,7 +36,7 @@ def __len__(self): def __getitem__(self, index): """Get a frame from the selected system.""" - b_data = self._data_system.get_item_paddle(index) + b_data = self._data_system.get_frame_paddle(index) b_data["natoms"] = self._natoms_vec return b_data diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index cf7b0b8e48..2cbe47cc3e 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -37,7 +37,7 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> dict[str, Any]: """Get a frame from the selected system.""" - b_data = self._data_system._get_single_frame(index) + b_data = self._data_system.get_item_torch(index) b_data["natoms"] = self._natoms_vec return b_data diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 2d9a88d95e..a895402cee 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -253,27 +253,18 @@ def get_item_torch(self, index: int) -> dict: index index of the frame """ - i = bisect.bisect_right(self.prefix_sum, index) - frames = self._load_set(self.dirs[i]) - frame = self._get_subdata(frames, index - self.prefix_sum[i]) - frame = self.reformat_data_torch(frame) - frame["fid"] = index - return frame + return self.get_single_frame(index) def get_item_paddle(self, index: int) -> dict: """Get a single frame data . 
The frame is picked from the data system by index. The index is coded across all the sets. + Same with PyTorch backend. Parameters ---------- index index of the frame """ - i = bisect.bisect_right(self.prefix_sum, index) - frames = self._load_set(self.dirs[i]) - frame = self._get_subdata(frames, index - self.prefix_sum[i]) - frame = self.reformat_data_torch(frame) - frame["fid"] = index - return frame + return self.get_single_frame(index) def get_batch(self, batch_size: int) -> dict: """Get a batch of data with `batch_size` frames. The frames are randomly picked from the data system. @@ -382,47 +373,7 @@ def get_natoms_vec(self, ntypes: int) -> np.ndarray: tmp = np.append(tmp, natoms_vec) return tmp.astype(np.int32) - def avg(self, key: str) -> float: - """Return the average value of an item.""" - if key not in self.data_dict.keys(): - raise RuntimeError(f"key {key} has not been added") - info = self.data_dict[key] - ndof = info["ndof"] - eners = [] - for ii in self.dirs: - data = self._load_set(ii) - ei = data[key].reshape([-1, ndof]) - eners.append(ei) - eners = np.concatenate(eners, axis=0) - if eners.size == 0: - return 0 - else: - return np.average(eners, axis=0) - - def _get_memmap(self, path: DPPath) -> np.memmap: - """Get or create a memory-mapped object for a given npy file.""" - memmap_key = Path(str(path)).absolute() - if memmap_key not in self.memmap_cache: - # Open the npy file to read its header and get shape/dtype - with open(str(path), "rb") as f: - version = np.lib.format.read_magic(f) - shape, fortran_order, dtype = np.lib.format._read_array_header( - f, version - ) - offset = f.tell() - order = "F" if fortran_order else "C" - # Create a read-only memmap and cache it - self.memmap_cache[memmap_key] = np.memmap( - str(path), - dtype=dtype, - mode="r", - shape=shape, - order=order, - offset=offset, - ) - return self.memmap_cache[memmap_key] - - def _get_single_frame(self, index: int) -> dict: + def get_single_frame(self, index: int) -> dict: 
"""Orchestrates loading a single frame efficiently using memmap.""" # 1. Find the correct set directory and local frame index set_idx = bisect.bisect_right(self.prefix_sum, index) @@ -437,7 +388,7 @@ def _get_single_frame(self, index: int) -> dict: # TODO: use async for key, vv in self.data_dict.items(): if vv["reduce"] is None: - frame_data["find_" + key], frame_data[key] = self._load_single_item( + frame_data["find_" + key], frame_data[key] = self._load_single_data( set_dir, key, local_idx ) @@ -457,14 +408,18 @@ def _get_single_frame(self, index: int) -> dict: frame_data[key] = np.sum(tmp_in, axis=0) # 4. Handle atom types (mixed or standard) - # TODO: mixed_type if self.mixed_type: type_path = set_dir / "real_atom_types.npy" mmap_types = self._get_memmap(type_path) real_type = mmap_types[local_idx].copy().astype(np.int32) if self.enforce_type_map: - real_type = self.type_idx_map[real_type].astype(np.int32) + try: + real_type = self.type_idx_map[real_type].astype(np.int32) + except IndexError as e: + raise IndexError( + f"some types in 'real_atom_types.npy' of set {set_dir} are not contained in {self.get_ntypes()} types!" + ) from e frame_data["type"] = real_type ntypes = self.get_ntypes() @@ -497,115 +452,22 @@ def _get_single_frame(self, index: int) -> dict: frame_data["fid"] = index return frame_data - def _load_single_item( - self, set_dir: DPPath, key: str, frame_idx: int - ) -> tuple[np.float32, np.ndarray]: - """ - Loads and processes data for a SINGLE frame from a SINGLE key, - fully replicating the logic from the original _load_data method. 
- """ - vv = self.data_dict[key] - path = set_dir / (key + ".npy") - - if vv["atomic"]: - natoms = self.natoms - idx_map = self.idx_map - # if type_sel, then revise natoms and idx_map - if vv["type_sel"] is not None: - natoms_sel = 0 - for jj in vv["type_sel"]: - natoms_sel += np.sum(self.atom_type == jj) - idx_map_sel = self._idx_map_sel(self.atom_type, vv["type_sel"]) - else: - natoms_sel = natoms - idx_map_sel = idx_map + def avg(self, key: str) -> float: + """Return the average value of an item.""" + if key not in self.data_dict.keys(): + raise RuntimeError(f"key {key} has not been added") + info = self.data_dict[key] + ndof = info["ndof"] + eners = [] + for ii in self.dirs: + data = self._load_set(ii) + ei = data[key].reshape([-1, ndof]) + eners.append(ei) + eners = np.concatenate(eners, axis=0) + if eners.size == 0: + return 0 else: - natoms = 1 - natoms_sel = 0 - idx_map_sel = None - ndof = vv["ndof"] - - # Determine target data type from requirements - dtype = vv.get("dtype") - if dtype is None: - dtype = ( - GLOBAL_ENER_FLOAT_PRECISION - if vv.get("high_prec") - else GLOBAL_NP_FLOAT_PRECISION - ) - - # Branch 1: File does not exist - if not path.is_file(): - if vv.get("must"): - raise RuntimeError(f"{path} not found!") - - # Create a default array based on requirements - if ( - vv["atomic"] - and vv["type_sel"] is not None - and not vv["output_natoms_for_type_sel"] - ): - natoms = natoms_sel - data = np.full([natoms, ndof], vv["default"], dtype=dtype) - return np.float32(0.0), data - - # Branch 2: File exists, use memmap - mmap_obj = self._get_memmap(path) - # Slice the single frame and make an in-memory copy for modification - data = mmap_obj[frame_idx].copy().astype(dtype, copy=False) - - try: - if vv["atomic"]: - # Handle type_sel logic - if vv["type_sel"] is not None: - sel_mask = np.isin(self.atom_type, vv["type_sel"]) - - if mmap_obj.shape[1] == natoms_sel * ndof: - if vv["output_natoms_for_type_sel"]: - tmp = np.zeros([natoms, ndof], 
dtype=data.dtype) - # sel_mask needs to be applied to the original atom layout - tmp[sel_mask] = data.reshape([natoms_sel, ndof]) - data = tmp - else: # output is natoms_sel - natoms = natoms_sel - idx_map = idx_map_sel - elif mmap_obj.shape[1] == natoms * ndof: - data = data.reshape([natoms, ndof]) - if vv["output_natoms_for_type_sel"]: - pass - else: - data = data[sel_mask] - idx_map = idx_map_sel - natoms = natoms_sel - else: # Shape mismatch error - raise ValueError( - f"The shape of the data {key} in {set_dir} has width {mmap_obj.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})" - ) - - # Handle special case for Hessian - if key == "hessian": - data = data.reshape(3 * natoms, 3 * natoms) - num_chunks, chunk_size = len(idx_map), 3 - idx_map_hess = np.arange( - num_chunks * chunk_size, dtype=int - ).reshape(num_chunks, chunk_size) - idx_map_hess = idx_map_hess[idx_map].flatten() - data = data[idx_map_hess, :] - data = data[:, idx_map_hess] - data = data.reshape(-1) - ndof = 3 * ndof * 3 * ndof # size of hessian is 3Natoms * 3Natoms - else: - # data should be 2D here: (natoms, ndof) - data = data.reshape([natoms, -1]) - data = data[idx_map, :] - - return np.float32(1.0), data - - except ValueError as err_message: - explanation = "This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." - log.error(str(err_message)) - log.error(explanation) - raise ValueError(str(err_message) + ". 
" + explanation) from err_message + return np.average(eners, axis=0) def _idx_map_sel(self, atom_type: np.ndarray, type_sel: list[int]) -> np.ndarray: new_types = [] @@ -626,6 +488,29 @@ def _get_natoms_2(self, ntypes: int) -> tuple[int, np.ndarray]: natoms_vec[ii] = np.count_nonzero(sample_type == ii) return natoms, natoms_vec + def _get_memmap(self, path: DPPath) -> np.memmap: + """Get or create a memory-mapped object for a given npy file.""" + memmap_key = Path(str(path)).absolute() + if memmap_key not in self.memmap_cache: + # Open the npy file to read its header and get shape/dtype + with open(str(path), "rb") as f: + version = np.lib.format.read_magic(f) + shape, fortran_order, dtype = np.lib.format._read_array_header( + f, version + ) + offset = f.tell() + order = "F" if fortran_order else "C" + # Create a read-only memmap and cache it + self.memmap_cache[memmap_key] = np.memmap( + str(path), + dtype=dtype, + mode="r", + shape=shape, + order=order, + offset=offset, + ) + return self.memmap_cache[memmap_key] + def _get_subdata( self, data: dict[str, Any], idx: Optional[np.ndarray] = None ) -> dict[str, Any]: @@ -920,6 +805,116 @@ def _load_data( data = np.repeat(data, repeat).reshape([nframes, -1]) return np.float32(0.0), data + def _load_single_data( + self, set_dir: DPPath, key: str, frame_idx: int + ) -> tuple[np.float32, np.ndarray]: + """ + Loads and processes data for a SINGLE frame from a SINGLE key, + fully replicating the logic from the original _load_data method. 
+ """ + vv = self.data_dict[key] + path = set_dir / (key + ".npy") + + if vv["atomic"]: + natoms = self.natoms + idx_map = self.idx_map + # if type_sel, then revise natoms and idx_map + if vv["type_sel"] is not None: + natoms_sel = 0 + for jj in vv["type_sel"]: + natoms_sel += np.sum(self.atom_type == jj) + idx_map_sel = self._idx_map_sel(self.atom_type, vv["type_sel"]) + else: + natoms_sel = natoms + idx_map_sel = idx_map + else: + natoms = 1 + natoms_sel = 0 + idx_map_sel = None + ndof = vv["ndof"] + + # Determine target data type from requirements + dtype = vv.get("dtype") + if dtype is None: + dtype = ( + GLOBAL_ENER_FLOAT_PRECISION + if vv.get("high_prec") + else GLOBAL_NP_FLOAT_PRECISION + ) + + # Branch 1: File does not exist + if not path.is_file(): + if vv.get("must"): + raise RuntimeError(f"{path} not found!") + + # Create a default array based on requirements + if ( + vv["atomic"] + and vv["type_sel"] is not None + and not vv["output_natoms_for_type_sel"] + ): + natoms = natoms_sel + data = np.full([natoms, ndof], vv["default"], dtype=dtype) + return np.float32(0.0), data + + # Branch 2: File exists, use memmap + mmap_obj = self._get_memmap(path) + # Slice the single frame and make an in-memory copy for modification + data = mmap_obj[frame_idx].copy().astype(dtype, copy=False) + + try: + if vv["atomic"]: + # Handle type_sel logic + if vv["type_sel"] is not None: + sel_mask = np.isin(self.atom_type, vv["type_sel"]) + + if mmap_obj.shape[1] == natoms_sel * ndof: + if vv["output_natoms_for_type_sel"]: + tmp = np.zeros([natoms, ndof], dtype=data.dtype) + # sel_mask needs to be applied to the original atom layout + tmp[sel_mask] = data.reshape([natoms_sel, ndof]) + data = tmp + else: # output is natoms_sel + natoms = natoms_sel + idx_map = idx_map_sel + elif mmap_obj.shape[1] == natoms * ndof: + data = data.reshape([natoms, ndof]) + if vv["output_natoms_for_type_sel"]: + pass + else: + data = data[sel_mask] + idx_map = idx_map_sel + natoms = natoms_sel + 
else: # Shape mismatch error + raise ValueError( + f"The shape of the data {key} in {set_dir} has width {mmap_obj.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})" + ) + + # Handle special case for Hessian + if key == "hessian": + data = data.reshape(3 * natoms, 3 * natoms) + num_chunks, chunk_size = len(idx_map), 3 + idx_map_hess = np.arange( + num_chunks * chunk_size, dtype=int + ).reshape(num_chunks, chunk_size) + idx_map_hess = idx_map_hess[idx_map].flatten() + data = data[idx_map_hess, :] + data = data[:, idx_map_hess] + data = data.reshape(-1) + ndof = 3 * ndof * 3 * ndof # size of hessian is 3Natoms * 3Natoms + else: + # data should be 2D here: (natoms, ndof) + data = data.reshape([natoms, -1]) + data = data[idx_map, :] + + return np.float32(1.0), data + + except ValueError as err_message: + explanation = "This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." + log.error(str(err_message)) + log.error(explanation) + raise ValueError(str(err_message) + ". 
" + explanation) from err_message + def _load_type(self, sys_path: DPPath) -> np.ndarray: atom_type = (sys_path / "type.raw").load_txt(ndmin=1).astype(np.int32) return atom_type From 63a40977a6b6ec05beae18812a1e6150f84702c9 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 20 Oct 2025 11:12:56 +0800 Subject: [PATCH 03/17] perf: use multithread to accelerate data loading --- deepmd/utils/data.py | 58 ++++++++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index a895402cee..8d3c5166c4 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -3,6 +3,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import bisect import logging +from concurrent.futures import ( + ThreadPoolExecutor, + as_completed, +) from pathlib import ( Path, ) @@ -71,10 +75,7 @@ def __init__( raise FileNotFoundError(f"No {set_prefix}.* is found in {sys_path}") self.dirs.sort() # check mix_type format - error_format_msg = ( - "if one of the set is of mixed_type format, " - "then all of the sets in this system should be of mixed_type format!" - ) + error_format_msg = "if one of the set is of mixed_type format, then all of the sets in this system should be of mixed_type format!" self.mixed_type = self._check_mode(self.dirs[0]) for set_item in self.dirs[1:]: assert self._check_mode(set_item) == self.mixed_type, error_format_msg @@ -384,28 +385,37 @@ def get_single_frame(self, index: int) -> dict: local_idx = index - self.prefix_sum[set_idx] frame_data = {} - # 2. Load all non-reduced items first - # TODO: use async - for key, vv in self.data_dict.items(): - if vv["reduce"] is None: - frame_data["find_" + key], frame_data[key] = self._load_single_data( - set_dir, key, local_idx - ) + # 2. 
Concurrently load all non-reduced items + non_reduced_keys = [k for k, v in self.data_dict.items() if v["reduce"] is None] + reduced_keys = [k for k, v in self.data_dict.items() if v["reduce"] is not None] + # Use a thread pool to parallelize loading + if non_reduced_keys: + with ThreadPoolExecutor(max_workers=len(non_reduced_keys)) as executor: + future_to_key = { + executor.submit( + self._load_single_data, set_dir, key, local_idx + ): key + for key in non_reduced_keys + } + for future in as_completed(future_to_key): + key = future_to_key[future] + try: + frame_data["find_" + key], frame_data[key] = future.result() + except Exception as exc: + log.error(f"{key!r} generated an exception: {exc}") + raise # 3. Compute reduced items from already loaded data - # TODO: use async - for key, vv in self.data_dict.items(): - if vv["reduce"] is not None: - k_in = vv["reduce"] - ndof = vv["ndof"] - frame_data["find_" + key] = frame_data["find_" + k_in] - # Reshape to (natoms, ndof) and sum over atom axis - tmp_in = ( - frame_data[k_in] - .reshape(-1, ndof) - .astype(GLOBAL_ENER_FLOAT_PRECISION) - ) - frame_data[key] = np.sum(tmp_in, axis=0) + for key in reduced_keys: + vv = self.data_dict[key] + k_in = vv["reduce"] + ndof = vv["ndof"] + frame_data["find_" + key] = frame_data["find_" + k_in] + # Reshape to (natoms, ndof) and sum over atom axis + tmp_in = ( + frame_data[k_in].reshape(-1, ndof).astype(GLOBAL_ENER_FLOAT_PRECISION) + ) + frame_data[key] = np.sum(tmp_in, axis=0) # 4. 
Handle atom types (mixed or standard) if self.mixed_type: From ef876c98c93dbebe8f2fb9e755fb3911d71bbf93 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 20 Oct 2025 17:35:17 +0800 Subject: [PATCH 04/17] fix: handle different .npy file versions in data loading --- deepmd/utils/data.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 8d3c5166c4..11b02722db 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -505,9 +505,12 @@ def _get_memmap(self, path: DPPath) -> np.memmap: # Open the npy file to read its header and get shape/dtype with open(str(path), "rb") as f: version = np.lib.format.read_magic(f) - shape, fortran_order, dtype = np.lib.format._read_array_header( - f, version - ) + if version[0] == 1: + shape, fortran_order, dtype = np.lib.format.read_array_header_1_0(f) + elif version[0] in [2, 3]: + shape, fortran_order, dtype = np.lib.format.read_array_header_2_0(f) + else: + raise ValueError(f"Unsupported .npy file version: {version}") offset = f.tell() order = "F" if fortran_order else "C" # Create a read-only memmap and cache it From b460dc133325fe1a4cfed51ef5d5c4a3b61d0527 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 20 Oct 2025 18:44:01 +0800 Subject: [PATCH 05/17] feat: use lru cache --- deepmd/env.py | 21 ++++++++++++++++++ deepmd/utils/data.py | 53 +++++++++++++++++++++++++------------------- 2 files changed, 51 insertions(+), 23 deletions(-) diff --git a/deepmd/env.py b/deepmd/env.py index 2c1241a36b..a5f7068b84 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging import os +import platform from configparser import ( ConfigParser, ) @@ -16,6 +17,7 @@ "GLOBAL_CONFIG", "GLOBAL_ENER_FLOAT_PRECISION", "GLOBAL_NP_FLOAT_PRECISION", + "LRU_CACHE_SIZE", "SHARED_LIB_DIR", "SHARED_LIB_MODULE", "global_float_prec", @@ -47,6 +49,25 @@ "DP_INTERFACE_PREC." 
) +# Dynamic calculation of cache size +DEFAULT_LRU_CACHE_SIZE = 888 +LRU_CACHE_SIZE = DEFAULT_LRU_CACHE_SIZE + +if platform.system() != "Windows": + try: + import resource + + soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) + safe_buffer = 128 + if soft_limit > safe_buffer + DEFAULT_LRU_CACHE_SIZE: + LRU_CACHE_SIZE = soft_limit - safe_buffer + else: + LRU_CACHE_SIZE = soft_limit // 2 + except ImportError: + LRU_CACHE_SIZE = DEFAULT_LRU_CACHE_SIZE +else: + LRU_CACHE_SIZE = DEFAULT_LRU_CACHE_SIZE + def set_env_if_empty(key: str, value: str, verbose: bool = True) -> None: """Set environment variable only if it is empty. diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 11b02722db..4dad67aaaf 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import bisect +import functools import logging from concurrent.futures import ( ThreadPoolExecutor, @@ -20,6 +21,7 @@ from deepmd.env import ( GLOBAL_ENER_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION, + LRU_CACHE_SIZE, ) from deepmd.utils import random as dp_random from deepmd.utils.path import ( @@ -500,29 +502,8 @@ def _get_natoms_2(self, ntypes: int) -> tuple[int, np.ndarray]: def _get_memmap(self, path: DPPath) -> np.memmap: """Get or create a memory-mapped object for a given npy file.""" - memmap_key = Path(str(path)).absolute() - if memmap_key not in self.memmap_cache: - # Open the npy file to read its header and get shape/dtype - with open(str(path), "rb") as f: - version = np.lib.format.read_magic(f) - if version[0] == 1: - shape, fortran_order, dtype = np.lib.format.read_array_header_1_0(f) - elif version[0] in [2, 3]: - shape, fortran_order, dtype = np.lib.format.read_array_header_2_0(f) - else: - raise ValueError(f"Unsupported .npy file version: {version}") - offset = f.tell() - order = "F" if fortran_order else "C" - # Create a read-only memmap and cache it - self.memmap_cache[memmap_key] = np.memmap( - str(path), - 
dtype=dtype, - mode="r", - shape=shape, - order=order, - offset=offset, - ) - return self.memmap_cache[memmap_key] + abs_path_str = str(Path(str(path)).absolute()) + return self._create_memmap(abs_path_str) def _get_subdata( self, data: dict[str, Any], idx: Optional[np.ndarray] = None @@ -962,6 +943,32 @@ def _check_pbc(self, sys_path: DPPath) -> bool: def _check_mode(self, set_path: DPPath) -> bool: return (set_path / "real_atom_types.npy").is_file() + @staticmethod + @functools.lru_cache(maxsize=LRU_CACHE_SIZE) + def _create_memmap(path_str: str) -> np.memmap: + """A cached helper function to create memmap objects. + Using lru_cache to limit the number of open file handles. + + Parameters + ---------- + path_str + The file path as a string. + """ + with open(path_str, "rb") as f: + version = np.lib.format.read_magic(f) + if version[0] == 1: + shape, fortran_order, dtype = np.lib.format.read_array_header_1_0(f) + elif version[0] in [2, 3]: + shape, fortran_order, dtype = np.lib.format.read_array_header_2_0(f) + else: + raise ValueError(f"Unsupported .npy file version: {version}") + offset = f.tell() + order = "F" if fortran_order else "C" + # Create a read-only memmap + return np.memmap( + path_str, dtype=dtype, mode="r", shape=shape, order=order, offset=offset + ) + class DataRequirementItem: """A class to store the data requirement for data systems. 
From 97ddb72313ac0fa252789f14316247e1668cec0d Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 24 Oct 2025 16:10:12 +0800 Subject: [PATCH 06/17] bug fix & simplify --- deepmd/env.py | 10 +++++----- deepmd/pd/utils/dataset.py | 2 +- deepmd/utils/data.py | 17 ++++++++--------- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/deepmd/env.py b/deepmd/env.py index a5f7068b84..d82a854e5c 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -50,8 +50,8 @@ ) # Dynamic calculation of cache size -DEFAULT_LRU_CACHE_SIZE = 888 -LRU_CACHE_SIZE = DEFAULT_LRU_CACHE_SIZE +_default_lru_cache_size = 888 +LRU_CACHE_SIZE = _default_lru_cache_size if platform.system() != "Windows": try: @@ -59,14 +59,14 @@ soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) safe_buffer = 128 - if soft_limit > safe_buffer + DEFAULT_LRU_CACHE_SIZE: + if soft_limit > safe_buffer + _default_lru_cache_size: LRU_CACHE_SIZE = soft_limit - safe_buffer else: LRU_CACHE_SIZE = soft_limit // 2 except ImportError: - LRU_CACHE_SIZE = DEFAULT_LRU_CACHE_SIZE + LRU_CACHE_SIZE = _default_lru_cache_size else: - LRU_CACHE_SIZE = DEFAULT_LRU_CACHE_SIZE + LRU_CACHE_SIZE = _default_lru_cache_size def set_env_if_empty(key: str, value: str, verbose: bool = True) -> None: diff --git a/deepmd/pd/utils/dataset.py b/deepmd/pd/utils/dataset.py index e2885e340e..1f0533d8fc 100644 --- a/deepmd/pd/utils/dataset.py +++ b/deepmd/pd/utils/dataset.py @@ -36,7 +36,7 @@ def __len__(self): def __getitem__(self, index): """Get a frame from the selected system.""" - b_data = self._data_system.get_frame_paddle(index) + b_data = self._data_system.get_item_paddle(index) b_data["natoms"] = self._natoms_vec return b_data diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 4dad67aaaf..2436ad3747 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -134,8 +134,6 @@ def __init__( self.shuffle_test = shuffle_test # set modifier self.modifier = modifier - # Add a cache for memory-mapped files - 
self.memmap_cache = {} # calculate prefix sum for get_item method frames_list = [self._get_nframes(item) for item in self.dirs] self.nframes = np.sum(frames_list) @@ -451,11 +449,12 @@ def get_single_frame(self, index: int) -> dict: # 6. Reshape atomic data to match expected format [natoms, ndof] for kk in self.data_dict.keys(): - if "find_" in kk: - pass - else: - if kk in frame_data and not self.data_dict[kk]["atomic"]: - frame_data[kk] = frame_data[kk].reshape(-1) + if ( + "find_" not in kk + and kk in frame_data + and not self.data_dict[kk]["atomic"] + ): + frame_data[kk] = frame_data[kk].reshape(-1) frame_data["atype"] = frame_data["type"] if not self.pbc: @@ -782,7 +781,7 @@ def _load_data( data = data.reshape([nframes, -1]) data = np.reshape(data, [nframes, ndof]) except ValueError as err_message: - explanation = "This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." + explanation = "This error may occur when your label mismatch its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." log.error(str(err_message)) log.error(explanation) raise ValueError(str(err_message) + ". " + explanation) from err_message @@ -904,7 +903,7 @@ def _load_single_data( return np.float32(1.0), data except ValueError as err_message: - explanation = "This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." + explanation = "This error may occur when your label mismatch its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." log.error(str(err_message)) log.error(explanation) raise ValueError(str(err_message) + ". 
" + explanation) from err_message From 2ce2ea4d9b9a1341ea4087f820157645383707ba Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 24 Oct 2025 17:25:23 +0800 Subject: [PATCH 07/17] bug fix for non-atomic data --- deepmd/utils/data.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 2436ad3747..fbbf2602f9 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -853,7 +853,12 @@ def _load_single_data( # Branch 2: File exists, use memmap mmap_obj = self._get_memmap(path) # Slice the single frame and make an in-memory copy for modification - data = mmap_obj[frame_idx].copy().astype(dtype, copy=False) + if mmap_obj.ndim == 0: + # Handle scalar data (0-dimensional array) + data = mmap_obj.copy().astype(dtype, copy=False) + else: + # Handle array data that can be indexed by frame + data = mmap_obj[frame_idx].copy().astype(dtype, copy=False) try: if vv["atomic"]: @@ -900,7 +905,9 @@ def _load_single_data( data = data.reshape([natoms, -1]) data = data[idx_map, :] - return np.float32(1.0), data + # Handle non-atomic data + # For non-atomic data, reshape to (ndof,) shape + return np.float32(1.0), data.reshape([ndof]) except ValueError as err_message: explanation = "This error may occur when your label mismatch its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." 
From e19efb0f4a2b5a86a415f9ec900460e5b133bfd4 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 24 Oct 2025 17:48:54 +0800 Subject: [PATCH 08/17] solve logging issue & shape issue for zero/one dim data --- deepmd/utils/data.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index fbbf2602f9..247645f7d6 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -401,8 +401,8 @@ def get_single_frame(self, index: int) -> dict: key = future_to_key[future] try: frame_data["find_" + key], frame_data[key] = future.result() - except Exception as exc: - log.error(f"{key!r} generated an exception: {exc}") + except Exception: + log.exception("Key %r generated an exception", key) raise # 3. Compute reduced items from already loaded data @@ -854,10 +854,17 @@ def _load_single_data( mmap_obj = self._get_memmap(path) # Slice the single frame and make an in-memory copy for modification if mmap_obj.ndim == 0: - # Handle scalar data (0-dimensional array) + # Scalar array + data = mmap_obj.copy().astype(dtype, copy=False) + elif mmap_obj.ndim == 1: + # Single-frame file (shape: [ndof]); only frame_idx==0 is valid + if frame_idx != 0: + raise IndexError( + f"frame index {frame_idx} out of range for single-frame file: {path}" + ) data = mmap_obj.copy().astype(dtype, copy=False) else: - # Handle array data that can be indexed by frame + # Regular [nframes, ...] 
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False) try: @@ -899,21 +906,31 @@ def _load_single_data( data = data[idx_map_hess, :] data = data[:, idx_map_hess] data = data.reshape(-1) - ndof = 3 * ndof * 3 * ndof # size of hessian is 3Natoms * 3Natoms + # size of hessian is 3Natoms * 3Natoms + ndof = 3 * ndof * 3 * ndof else: # data should be 2D here: (natoms, ndof) data = data.reshape([natoms, -1]) data = data[idx_map, :] + # Atomic: return [natoms, ndof] or flattened hessian above + return np.float32(1.0), data - # Handle non-atomic data - # For non-atomic data, reshape to (ndof,) shape + # Non-atomic: return [ndof] return np.float32(1.0), data.reshape([ndof]) except ValueError as err_message: - explanation = "This error may occur when your label mismatch its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." - log.error(str(err_message)) - log.error(explanation) - raise ValueError(str(err_message) + ". " + explanation) from err_message + explanation = ( + "This error may occur when your label mismatches its name, " + "e.g., global tensor stored in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." + ) + log.exception( + "Single-frame load failed for key=%s, set=%s, frame=%d. %s", + key, + set_dir, + frame_idx, + explanation, + ) + raise ValueError(f"{err_message}. 
{explanation}") from err_message def _load_type(self, sys_path: DPPath) -> np.ndarray: atom_type = (sys_path / "type.raw").load_txt(ndmin=1).astype(np.int32) From 291cbdbcf3add9916e7437fc885278f2069fd747 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 24 Oct 2025 19:08:23 +0800 Subject: [PATCH 09/17] bug fix for corner case: single frame dataset --- deepmd/utils/data.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 247645f7d6..8db01ebc76 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -376,13 +376,15 @@ def get_natoms_vec(self, ntypes: int) -> np.ndarray: def get_single_frame(self, index: int) -> dict: """Orchestrates loading a single frame efficiently using memmap.""" + if index < 0 or index >= self.nframes: + raise IndexError(f"Frame index {index} out of range [0, {self.nframes})") # 1. Find the correct set directory and local frame index set_idx = bisect.bisect_right(self.prefix_sum, index) set_dir = self.dirs[set_idx] if not isinstance(set_dir, DPPath): set_dir = DPPath(set_dir) # Calculate local index within the set.* directory - local_idx = index - self.prefix_sum[set_idx] + local_idx = index - (0 if set_idx == 0 else self.prefix_sum[set_idx - 1]) frame_data = {} # 2. 
Concurrently load all non-reduced items @@ -854,18 +856,18 @@ def _load_single_data( mmap_obj = self._get_memmap(path) # Slice the single frame and make an in-memory copy for modification if mmap_obj.ndim == 0: - # Scalar array - data = mmap_obj.copy().astype(dtype, copy=False) + # case: single frame data && non-atomic + data = mmap_obj.copy().astype(dtype, copy=False).reshape(1, -1) elif mmap_obj.ndim == 1: - # Single-frame file (shape: [ndof]); only frame_idx==0 is valid + # case: single frame data && atomic if frame_idx != 0: raise IndexError( f"frame index {frame_idx} out of range for single-frame file: {path}" ) - data = mmap_obj.copy().astype(dtype, copy=False) + data = mmap_obj.copy().astype(dtype, copy=False).reshape(1, -1) else: - # Regular [nframes, ...] - data = mmap_obj[frame_idx].copy().astype(dtype, copy=False) + # case: multi-frame data + data = mmap_obj[frame_idx].copy().astype(dtype, copy=False).reshape(1, -1) try: if vv["atomic"]: @@ -873,7 +875,7 @@ def _load_single_data( if vv["type_sel"] is not None: sel_mask = np.isin(self.atom_type, vv["type_sel"]) - if mmap_obj.shape[1] == natoms_sel * ndof: + if data.shape[1] == natoms_sel * ndof: if vv["output_natoms_for_type_sel"]: tmp = np.zeros([natoms, ndof], dtype=data.dtype) # sel_mask needs to be applied to the original atom layout @@ -882,7 +884,7 @@ def _load_single_data( else: # output is natoms_sel natoms = natoms_sel idx_map = idx_map_sel - elif mmap_obj.shape[1] == natoms * ndof: + elif data.shape[1] == natoms * ndof: data = data.reshape([natoms, ndof]) if vv["output_natoms_for_type_sel"]: pass @@ -892,7 +894,7 @@ def _load_single_data( natoms = natoms_sel else: # Shape mismatch error raise ValueError( - f"The shape of the data {key} in {set_dir} has width {mmap_obj.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})" + f"The shape of the data {key} in {set_dir} has width {data.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})" ) # 
Handle special case for Hessian @@ -907,7 +909,7 @@ def _load_single_data( data = data[:, idx_map_hess] data = data.reshape(-1) # size of hessian is 3Natoms * 3Natoms - ndof = 3 * ndof * 3 * ndof + # ndof = 3 * ndof * 3 * ndof else: # data should be 2D here: (natoms, ndof) data = data.reshape([natoms, -1]) From aafccb2c5b6225b16f70045362a2d39d8f8473ad Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 24 Oct 2025 21:14:35 +0800 Subject: [PATCH 10/17] use None to add one fake dim for single frame datasets --- deepmd/utils/data.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 8db01ebc76..97fe7becc5 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -854,20 +854,11 @@ def _load_single_data( # Branch 2: File exists, use memmap mmap_obj = self._get_memmap(path) + # corner case: single frame + if self._get_nframes(set_dir) == 1: + mmap_obj = mmap_obj[None, ...] # Slice the single frame and make an in-memory copy for modification - if mmap_obj.ndim == 0: - # case: single frame data && non-atomic - data = mmap_obj.copy().astype(dtype, copy=False).reshape(1, -1) - elif mmap_obj.ndim == 1: - # case: single frame data && atomic - if frame_idx != 0: - raise IndexError( - f"frame index {frame_idx} out of range for single-frame file: {path}" - ) - data = mmap_obj.copy().astype(dtype, copy=False).reshape(1, -1) - else: - # case: multi-frame data - data = mmap_obj[frame_idx].copy().astype(dtype, copy=False).reshape(1, -1) + data = mmap_obj[frame_idx].copy().astype(dtype, copy=False) try: if vv["atomic"]: @@ -875,7 +866,7 @@ def _load_single_data( if vv["type_sel"] is not None: sel_mask = np.isin(self.atom_type, vv["type_sel"]) - if data.shape[1] == natoms_sel * ndof: + if mmap_obj.shape[1] == natoms_sel * ndof: if vv["output_natoms_for_type_sel"]: tmp = np.zeros([natoms, ndof], dtype=data.dtype) # sel_mask needs to be applied to the original atom layout @@ -884,7 
+875,7 @@ def _load_single_data( else: # output is natoms_sel natoms = natoms_sel idx_map = idx_map_sel - elif data.shape[1] == natoms * ndof: + elif mmap_obj.shape[1] == natoms * ndof: data = data.reshape([natoms, ndof]) if vv["output_natoms_for_type_sel"]: pass @@ -894,7 +885,7 @@ def _load_single_data( natoms = natoms_sel else: # Shape mismatch error raise ValueError( - f"The shape of the data {key} in {set_dir} has width {data.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})" + f"The shape of the data {key} in {set_dir} has width {mmap_obj.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})" ) # Handle special case for Hessian @@ -911,14 +902,13 @@ def _load_single_data( # size of hessian is 3Natoms * 3Natoms # ndof = 3 * ndof * 3 * ndof else: - # data should be 2D here: (natoms, ndof) + # data should be 2D here: [natoms, ndof] data = data.reshape([natoms, -1]) data = data[idx_map, :] - # Atomic: return [natoms, ndof] or flattened hessian above - return np.float32(1.0), data + # Atomic: return [natoms, ndof] or flattened hessian above # Non-atomic: return [ndof] - return np.float32(1.0), data.reshape([ndof]) + return np.float32(1.0), data except ValueError as err_message: explanation = ( From 0c8762567a26b77afeec980c6a6e63c1bf99dca7 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sat, 25 Oct 2025 10:42:38 +0800 Subject: [PATCH 11/17] modify default array when file does not exist --- deepmd/utils/data.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 97fe7becc5..54a7f7a567 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -843,13 +843,13 @@ def _load_single_data( raise RuntimeError(f"{path} not found!") # Create a default array based on requirements - if ( - vv["atomic"] - and vv["type_sel"] is not None - and not vv["output_natoms_for_type_sel"] - ): - natoms = natoms_sel - data = np.full([natoms, ndof], 
vv["default"], dtype=dtype) + if vv["atomic"]: + if vv["type_sel"] is not None and not vv["output_natoms_for_type_sel"]: + natoms = natoms_sel + data = np.full([natoms, ndof], vv["default"], dtype=dtype) + else: + # For non-atomic data, shape should be [ndof] + data = np.full([ndof], vv["default"], dtype=dtype) return np.float32(0.0), data # Branch 2: File exists, use memmap From 47845a4f91b9b734b107fefd0a75e6ff8be0ddfc Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sat, 25 Oct 2025 12:30:19 +0800 Subject: [PATCH 12/17] test:atom_polarizability also need to reshape to adapt current implementation --- deepmd/utils/data.py | 2 ++ source/tests/pt/test_loss_tensor.py | 10 ++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 54a7f7a567..225770fc01 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -905,6 +905,8 @@ def _load_single_data( # data should be 2D here: [natoms, ndof] data = data.reshape([natoms, -1]) data = data[idx_map, :] + else: + data = data.reshape([ndof]) # Atomic: return [natoms, ndof] or flattened hessian above # Non-atomic: return [ndof] diff --git a/source/tests/pt/test_loss_tensor.py b/source/tests/pt/test_loss_tensor.py index 5802c0b775..933cfbb730 100644 --- a/source/tests/pt/test_loss_tensor.py +++ b/source/tests/pt/test_loss_tensor.py @@ -24,9 +24,11 @@ DataRequirementItem, ) -from ..seed import ( - GLOBAL_SEED, -) +# from ..seed import ( +# GLOBAL_SEED, +# ) + +GLOBAL_SEED = 7 CUR_DIR = os.path.dirname(__file__) @@ -57,7 +59,7 @@ def get_single_batch(dataset, index=None): if key in np_batch.keys(): np_batch[key] = np.expand_dims(np_batch[key], axis=0) pt_batch[key] = torch.as_tensor(np_batch[key], device=env.DEVICE) - if key in ["coord", "atom_dipole"]: + if key in ["coord", "atom_dipole", "atom_polarizability"]: np_batch[key] = np_batch[key].reshape(1, -1) np_batch["natoms"] = np_batch["natoms"][0] return np_batch, pt_batch From 
d877e5b78bfe6478a609b3e92816e5845f05a36b Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sat, 25 Oct 2025 15:09:13 +0800 Subject: [PATCH 13/17] Use path + modification time as cache key to detect file changes --- deepmd/env.py | 2 +- deepmd/utils/data.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/deepmd/env.py b/deepmd/env.py index d82a854e5c..ccc09356ae 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -50,7 +50,7 @@ ) # Dynamic calculation of cache size -_default_lru_cache_size = 888 +_default_lru_cache_size = 512 LRU_CACHE_SIZE = _default_lru_cache_size if platform.system() != "Windows": diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 225770fc01..2d664227c9 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -502,9 +502,13 @@ def _get_natoms_2(self, ntypes: int) -> tuple[int, np.ndarray]: return natoms, natoms_vec def _get_memmap(self, path: DPPath) -> np.memmap: - """Get or create a memory-mapped object for a given npy file.""" - abs_path_str = str(Path(str(path)).absolute()) - return self._create_memmap(abs_path_str) + """Get or create a memory-mapped object for a given npy file. + Uses file path and modification time as cache keys to detect file changes + and invalidate cache when files are modified. + """ + abs_path = Path(str(path)).absolute() + file_mtime = abs_path.stat().st_mtime + return self._create_memmap(str(abs_path), str(file_mtime)) def _get_subdata( self, data: dict[str, Any], idx: Optional[np.ndarray] = None @@ -962,7 +966,7 @@ def _check_mode(self, set_path: DPPath) -> bool: @staticmethod @functools.lru_cache(maxsize=LRU_CACHE_SIZE) - def _create_memmap(path_str: str) -> np.memmap: + def _create_memmap(path_str: str, mtime_str: str) -> np.memmap: """A cached helper function to create memmap objects. Using lru_cache to limit the number of open file handles. @@ -970,6 +974,8 @@ def _create_memmap(path_str: str) -> np.memmap: ---------- path_str The file path as a string. 
+ mtime_str + The modification time as a string, used for cache invalidation. """ with open(path_str, "rb") as f: version = np.lib.format.read_magic(f) From e6d4638e7a77b80ee4710079e6e3ef216d8da5e3 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 26 Oct 2025 14:14:39 +0800 Subject: [PATCH 14/17] perf: avoid repeated _get_nframes --- deepmd/utils/data.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 2d664227c9..80c13c97e0 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -385,6 +385,12 @@ def get_single_frame(self, index: int) -> dict: set_dir = DPPath(set_dir) # Calculate local index within the set.* directory local_idx = index - (0 if set_idx == 0 else self.prefix_sum[set_idx - 1]) + # Calculate the number of frames in this set to avoid redundant _get_nframes calls + set_nframes = ( + self.prefix_sum[set_idx] + if set_idx == 0 + else self.prefix_sum[set_idx] - self.prefix_sum[set_idx - 1] + ) frame_data = {} # 2. Concurrently load all non-reduced items @@ -395,7 +401,7 @@ def get_single_frame(self, index: int) -> dict: with ThreadPoolExecutor(max_workers=len(non_reduced_keys)) as executor: future_to_key = { executor.submit( - self._load_single_data, set_dir, key, local_idx + self._load_single_data, set_dir, key, local_idx, set_nframes ): key for key in non_reduced_keys } @@ -805,11 +811,22 @@ def _load_data( return np.float32(0.0), data def _load_single_data( - self, set_dir: DPPath, key: str, frame_idx: int + self, set_dir: DPPath, key: str, frame_idx: int, set_nframes: int ) -> tuple[np.float32, np.ndarray]: """ Loads and processes data for a SINGLE frame from a SINGLE key, fully replicating the logic from the original _load_data method. 
+ + Parameters + ---------- + set_dir : DPPath + The directory path of the set + key : str + The key name of the data to load + frame_idx : int + The local frame index within the set + set_nframes : int + The total number of frames in this set (to avoid redundant _get_nframes calls) """ vv = self.data_dict[key] path = set_dir / (key + ".npy") @@ -859,7 +876,7 @@ def _load_single_data( # Branch 2: File exists, use memmap mmap_obj = self._get_memmap(path) # corner case: single frame - if self._get_nframes(set_dir) == 1: + if set_nframes == 1: mmap_obj = mmap_obj[None, ...] # Slice the single frame and make an in-memory copy for modification data = mmap_obj[frame_idx].copy().astype(dtype, copy=False) From 40e7a1212ac48cb0844804d88da11c06df371cb4 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 26 Oct 2025 14:27:23 +0800 Subject: [PATCH 15/17] perf: vectorize type selection --- deepmd/utils/data.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 80c13c97e0..ccb0abda4b 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -102,10 +102,13 @@ def __init__( f"Elements {missing_elements} are not present in the provided `type_map`." 
) if not self.mixed_type: - atom_type_ = [ - type_map.index(self.type_map[ii]) for ii in self.atom_type - ] - self.atom_type = np.array(atom_type_, dtype=np.int32) + # Use vectorized operation for better performance with large atom counts + # Create a mapping array where old_type_idx -> new_type_idx + max_old_type = max(self.atom_type) + 1 + type_mapping = np.zeros(max_old_type, dtype=np.int32) + for old_idx in range(len(self.type_map)): + type_mapping[old_idx] = type_map.index(self.type_map[old_idx]) + self.atom_type = type_mapping[self.atom_type].astype(np.int32) else: self.enforce_type_map = True sorter = np.argsort(type_map) @@ -489,11 +492,9 @@ def avg(self, key: str) -> float: return np.average(eners, axis=0) def _idx_map_sel(self, atom_type: np.ndarray, type_sel: list[int]) -> np.ndarray: - new_types = [] - for ii in atom_type: - if ii in type_sel: - new_types.append(ii) - new_types = np.array(new_types, dtype=int) + # Use vectorized operations instead of Python loop + sel_mask = np.isin(atom_type, type_sel) + new_types = atom_type[sel_mask] natoms = new_types.shape[0] idx = np.arange(natoms, dtype=np.int64) idx_map = np.lexsort((idx, new_types)) @@ -717,9 +718,9 @@ def _load_data( idx_map = self.idx_map # if type_sel, then revise natoms and idx_map if type_sel is not None: - natoms_sel = 0 - for jj in type_sel: - natoms_sel += np.sum(self.atom_type == jj) + # Use vectorized operations for better performance + sel_mask = np.isin(self.atom_type, type_sel) + natoms_sel = np.sum(sel_mask) idx_map_sel = self._idx_map_sel(self.atom_type, type_sel) else: natoms_sel = natoms @@ -747,7 +748,6 @@ def _load_data( tmp = np.zeros( [nframes, natoms, ndof_], dtype=data.dtype ) - sel_mask = np.isin(self.atom_type, type_sel) tmp[:, sel_mask] = data.reshape( [nframes, natoms_sel, ndof_] ) @@ -760,7 +760,6 @@ def _load_data( if output_natoms_for_type_sel: pass else: - sel_mask = np.isin(self.atom_type, type_sel) data = data.reshape([nframes, natoms, ndof_]) data = data[:, 
sel_mask] natoms = natoms_sel @@ -836,9 +835,9 @@ def _load_single_data( idx_map = self.idx_map # if type_sel, then revise natoms and idx_map if vv["type_sel"] is not None: - natoms_sel = 0 - for jj in vv["type_sel"]: - natoms_sel += np.sum(self.atom_type == jj) + # Use vectorized operations for better performance + sel_mask = np.isin(self.atom_type, vv["type_sel"]) + natoms_sel = np.sum(sel_mask) idx_map_sel = self._idx_map_sel(self.atom_type, vv["type_sel"]) else: natoms_sel = natoms @@ -885,8 +884,6 @@ def _load_single_data( if vv["atomic"]: # Handle type_sel logic if vv["type_sel"] is not None: - sel_mask = np.isin(self.atom_type, vv["type_sel"]) - if mmap_obj.shape[1] == natoms_sel * ndof: if vv["output_natoms_for_type_sel"]: tmp = np.zeros([natoms, ndof], dtype=data.dtype) From f0bbaa4b522e1dbe2a549212a731dad4d0ea1e5d Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 27 Oct 2025 11:50:57 +0800 Subject: [PATCH 16/17] typo fix --- deepmd/env.py | 2 -- source/tests/pt/test_loss_tensor.py | 8 +++----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/deepmd/env.py b/deepmd/env.py index ccc09356ae..bf1e794755 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -65,8 +65,6 @@ LRU_CACHE_SIZE = soft_limit // 2 except ImportError: LRU_CACHE_SIZE = _default_lru_cache_size -else: - LRU_CACHE_SIZE = _default_lru_cache_size def set_env_if_empty(key: str, value: str, verbose: bool = True) -> None: diff --git a/source/tests/pt/test_loss_tensor.py b/source/tests/pt/test_loss_tensor.py index 933cfbb730..67dcb568e1 100644 --- a/source/tests/pt/test_loss_tensor.py +++ b/source/tests/pt/test_loss_tensor.py @@ -24,11 +24,9 @@ DataRequirementItem, ) -# from ..seed import ( -# GLOBAL_SEED, -# ) - -GLOBAL_SEED = 7 +from ..seed import ( + GLOBAL_SEED, +) CUR_DIR = os.path.dirname(__file__) From 783281458776b2b3e7bbf83f83161e6b0dcc9ec5 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 27 Oct 2025 14:52:50 +0800 Subject: [PATCH 17/17] remove unnecessary try...except --- 
deepmd/env.py | 19 ++++++++----------- deepmd/utils/data.py | 6 +----- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/deepmd/env.py b/deepmd/env.py index bf1e794755..7b29a338f1 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -54,17 +54,14 @@ LRU_CACHE_SIZE = _default_lru_cache_size if platform.system() != "Windows": - try: - import resource - - soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) - safe_buffer = 128 - if soft_limit > safe_buffer + _default_lru_cache_size: - LRU_CACHE_SIZE = soft_limit - safe_buffer - else: - LRU_CACHE_SIZE = soft_limit // 2 - except ImportError: - LRU_CACHE_SIZE = _default_lru_cache_size + import resource + + soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) + safe_buffer = 128 + if soft_limit > safe_buffer + _default_lru_cache_size: + LRU_CACHE_SIZE = soft_limit - safe_buffer + else: + LRU_CACHE_SIZE = soft_limit // 2 def set_env_if_empty(key: str, value: str, verbose: bool = True) -> None: diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index ccb0abda4b..b34d4d62df 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -410,11 +410,7 @@ def get_single_frame(self, index: int) -> dict: } for future in as_completed(future_to_key): key = future_to_key[future] - try: - frame_data["find_" + key], frame_data[key] = future.result() - except Exception: - log.exception("Key %r generated an exception", key) - raise + frame_data["find_" + key], frame_data[key] = future.result() # 3. Compute reduced items from already loaded data for key in reduced_keys: