21 changes: 21 additions & 0 deletions deepmd/env.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import platform
from configparser import (
ConfigParser,
)
@@ -16,6 +17,7 @@
"GLOBAL_CONFIG",
"GLOBAL_ENER_FLOAT_PRECISION",
"GLOBAL_NP_FLOAT_PRECISION",
"LRU_CACHE_SIZE",
"SHARED_LIB_DIR",
"SHARED_LIB_MODULE",
"global_float_prec",
@@ -47,6 +49,25 @@
"DP_INTERFACE_PREC."
)

# Dynamically size the LRU cache from the process's open-file limit
_default_lru_cache_size = 512
LRU_CACHE_SIZE = _default_lru_cache_size

if platform.system() != "Windows":
try:
import resource

soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
safe_buffer = 128
if soft_limit > safe_buffer + _default_lru_cache_size:
LRU_CACHE_SIZE = soft_limit - safe_buffer
else:
LRU_CACHE_SIZE = soft_limit // 2
except ImportError:
LRU_CACHE_SIZE = _default_lru_cache_size
else:
LRU_CACHE_SIZE = _default_lru_cache_size
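# Worked example of the heuristic above: a soft limit of 1024 clears the
# 128 + 512 threshold, giving LRU_CACHE_SIZE = 1024 - 128 = 896; a tight
# limit of 256 does not, giving 256 // 2 = 128. The 128-descriptor buffer
# leaves headroom for sockets, logs, and other open files.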


def set_env_if_empty(key: str, value: str, verbose: bool = True) -> None:
"""Set environment variable only if it is empty.
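For reference, a minimal standalone sketch of the same heuristic (POSIX only; the 128/512 thresholds mirror the block above), handy for checking what LRU_CACHE_SIZE a given machine would end up with:

import resource

soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
size = soft - 128 if soft > 128 + 512 else soft // 2
print(f"soft limit {soft} -> LRU_CACHE_SIZE {size}")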
285 changes: 268 additions & 17 deletions deepmd/utils/data.py
@@ -2,7 +2,15 @@

# SPDX-License-Identifier: LGPL-3.0-or-later
import bisect
import functools
import logging
from concurrent.futures import (
ThreadPoolExecutor,
as_completed,
)
from pathlib import (
Path,
)
from typing import (
Any,
Optional,
@@ -13,6 +21,7 @@
from deepmd.env import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
LRU_CACHE_SIZE,
)
from deepmd.utils import random as dp_random
from deepmd.utils.path import (
@@ -68,10 +77,7 @@ def __init__(
raise FileNotFoundError(f"No {set_prefix}.* is found in {sys_path}")
self.dirs.sort()
# check mix_type format
error_format_msg = (
"if one of the set is of mixed_type format, "
"then all of the sets in this system should be of mixed_type format!"
)
error_format_msg = "if one of the set is of mixed_type format, then all of the sets in this system should be of mixed_type format!"
self.mixed_type = self._check_mode(self.dirs[0])
for set_item in self.dirs[1:]:
assert self._check_mode(set_item) == self.mixed_type, error_format_msg
@@ -248,27 +254,18 @@ def get_item_torch(self, index: int) -> dict:
index
index of the frame
"""
i = bisect.bisect_right(self.prefix_sum, index)
frames = self._load_set(self.dirs[i])
frame = self._get_subdata(frames, index - self.prefix_sum[i])
frame = self.reformat_data_torch(frame)
frame["fid"] = index
return frame
return self.get_single_frame(index)

def get_item_paddle(self, index: int) -> dict:
"""Get a single frame data . The frame is picked from the data system by index. The index is coded across all the sets.
Same with PyTorch backend.

Parameters
----------
index
index of the frame
"""
i = bisect.bisect_right(self.prefix_sum, index)
frames = self._load_set(self.dirs[i])
frame = self._get_subdata(frames, index - self.prefix_sum[i])
frame = self.reformat_data_torch(frame)
frame["fid"] = index
return frame
return self.get_single_frame(index)

def get_batch(self, batch_size: int) -> dict:
"""Get a batch of data with `batch_size` frames. The frames are randomly picked from the data system.
@@ -377,6 +374,97 @@ def get_natoms_vec(self, ntypes: int) -> np.ndarray:
tmp = np.append(tmp, natoms_vec)
return tmp.astype(np.int32)

def get_single_frame(self, index: int) -> dict:
"""Orchestrates loading a single frame efficiently using memmap."""
if index < 0 or index >= self.nframes:
raise IndexError(f"Frame index {index} out of range [0, {self.nframes})")
# 1. Find the correct set directory and local frame index
set_idx = bisect.bisect_right(self.prefix_sum, index)
set_dir = self.dirs[set_idx]
if not isinstance(set_dir, DPPath):
set_dir = DPPath(set_dir)
# Calculate local index within the set.* directory
local_idx = index - (0 if set_idx == 0 else self.prefix_sum[set_idx - 1])
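# Example: with set sizes [100, 100], prefix_sum = [100, 200]; global index
# 150 gives set_idx = 1 and local_idx = 150 - 100 = 50.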

frame_data = {}
# 2. Concurrently load all non-reduced items
non_reduced_keys = [k for k, v in self.data_dict.items() if v["reduce"] is None]
reduced_keys = [k for k, v in self.data_dict.items() if v["reduce"] is not None]
# Use a thread pool to parallelize loading
if non_reduced_keys:
with ThreadPoolExecutor(max_workers=len(non_reduced_keys)) as executor:
future_to_key = {
executor.submit(
self._load_single_data, set_dir, key, local_idx
): key
for key in non_reduced_keys
}
for future in as_completed(future_to_key):
key = future_to_key[future]
try:
frame_data["find_" + key], frame_data[key] = future.result()
except Exception:
log.exception("Key %r generated an exception", key)
raise
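# Threads are enough here: the per-key work is dominated by file I/O and
# NumPy copies, both of which release the GIL.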

# 3. Compute reduced items from already loaded data
for key in reduced_keys:
vv = self.data_dict[key]
k_in = vv["reduce"]
ndof = vv["ndof"]
frame_data["find_" + key] = frame_data["find_" + k_in]
# Reshape to (natoms, ndof) and sum over atom axis
tmp_in = (
frame_data[k_in].reshape(-1, ndof).astype(GLOBAL_ENER_FLOAT_PRECISION)
)
frame_data[key] = np.sum(tmp_in, axis=0)
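# Example: a per-atom quantity with ndof = 1 arrives as shape (natoms,),
# is reshaped to (natoms, 1), and the sum over axis 0 yields the global
# (1,) value.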

# 4. Handle atom types (mixed or standard)
if self.mixed_type:
type_path = set_dir / "real_atom_types.npy"
mmap_types = self._get_memmap(type_path)
real_type = mmap_types[local_idx].copy().astype(np.int32)

if self.enforce_type_map:
try:
real_type = self.type_idx_map[real_type].astype(np.int32)
except IndexError as e:
raise IndexError(
f"some types in 'real_atom_types.npy' of set {set_dir} are not contained in {self.get_ntypes()} types!"
) from e

frame_data["type"] = real_type
ntypes = self.get_ntypes()
natoms = len(real_type)
# Use bincount for efficient counting of each type
natoms_vec = np.bincount(
real_type[real_type >= 0], minlength=ntypes
).astype(np.int32)
frame_data["real_natoms_vec"] = np.concatenate(
(np.array([natoms, natoms], dtype=np.int32), natoms_vec)
)
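# Example: real_type = [0, 0, 1] with ntypes = 3 gives natoms_vec = [2, 1, 0]
# and real_natoms_vec = [3, 3, 2, 1, 0]; negative (padding) types are
# excluded by the mask.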
else:
frame_data["type"] = self.atom_type[self.idx_map]

# 5. Standardize keys
frame_data = {kk.replace("atomic", "atom"): vv for kk, vv in frame_data.items()}

# 6. Flatten non-atomic data; atomic data already has shape [natoms, ndof]
for kk in self.data_dict.keys():
if (
"find_" not in kk
and kk in frame_data
and not self.data_dict[kk]["atomic"]
):
frame_data[kk] = frame_data[kk].reshape(-1)
frame_data["atype"] = frame_data["type"]

if not self.pbc:
frame_data["box"] = None

frame_data["fid"] = index
return frame_data

def avg(self, key: str) -> float:
"""Return the average value of an item."""
if key not in self.data_dict.keys():
@@ -413,6 +501,15 @@ def _get_natoms_2(self, ntypes: int) -> tuple[int, np.ndarray]:
natoms_vec[ii] = np.count_nonzero(sample_type == ii)
return natoms, natoms_vec

def _get_memmap(self, path: DPPath) -> np.memmap:
"""Get or create a memory-mapped object for a given npy file.
Uses file path and modification time as cache keys to detect file changes
and invalidate cache when files are modified.
"""
abs_path = Path(str(path)).absolute()
file_mtime = abs_path.stat().st_mtime
return self._create_memmap(str(abs_path), str(file_mtime))
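# Rewriting the file changes st_mtime, so the (path, mtime) key misses the
# cache and a fresh memmap is created; the stale entry ages out of the LRU.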

def _get_subdata(
self, data: dict[str, Any], idx: Optional[np.ndarray] = None
) -> dict[str, Any]:
@@ -690,7 +787,7 @@ def _load_data(
data = data.reshape([nframes, -1])
data = np.reshape(data, [nframes, ndof])
except ValueError as err_message:
explanation = "This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`."
explanation = "This error may occur when your label mismatch its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`."
log.error(str(err_message))
log.error(explanation)
raise ValueError(str(err_message) + ". " + explanation) from err_message
@@ -707,6 +804,132 @@ def _load_data(
data = np.repeat(data, repeat).reshape([nframes, -1])
return np.float32(0.0), data

def _load_single_data(
self, set_dir: DPPath, key: str, frame_idx: int
) -> tuple[np.float32, np.ndarray]:
"""
Load and process the data of a single key for a single frame,
replicating the logic of the original _load_data method.
"""
vv = self.data_dict[key]
path = set_dir / (key + ".npy")

if vv["atomic"]:
natoms = self.natoms
idx_map = self.idx_map
# if type_sel, then revise natoms and idx_map
if vv["type_sel"] is not None:
natoms_sel = 0
for jj in vv["type_sel"]:
natoms_sel += np.sum(self.atom_type == jj)
idx_map_sel = self._idx_map_sel(self.atom_type, vv["type_sel"])
else:
natoms_sel = natoms
idx_map_sel = idx_map
else:
natoms = 1
natoms_sel = 0
idx_map_sel = None
ndof = vv["ndof"]

# Determine target data type from requirements
dtype = vv.get("dtype")
if dtype is None:
dtype = (
GLOBAL_ENER_FLOAT_PRECISION
if vv.get("high_prec")
else GLOBAL_NP_FLOAT_PRECISION
)

# Branch 1: File does not exist
if not path.is_file():
if vv.get("must"):
raise RuntimeError(f"{path} not found!")

# Create a default array based on requirements
if vv["atomic"]:
if vv["type_sel"] is not None and not vv["output_natoms_for_type_sel"]:
natoms = natoms_sel
data = np.full([natoms, ndof], vv["default"], dtype=dtype)
else:
# For non-atomic data, shape should be [ndof]
data = np.full([ndof], vv["default"], dtype=dtype)
return np.float32(0.0), data
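# Example: an optional atomic item with ndof = 3 yields a default-filled
# array of shape (natoms, 3) and a find flag of 0.0 marking the data absent.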

# Branch 2: File exists, use memmap
mmap_obj = self._get_memmap(path)
# Corner case: a single-frame set may be stored without the leading frame axis; add one so frame indexing is uniform
if self._get_nframes(set_dir) == 1:
mmap_obj = mmap_obj[None, ...]
# Slice the single frame and make an in-memory copy for modification
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False)

try:
if vv["atomic"]:
# Handle type_sel logic
if vv["type_sel"] is not None:
sel_mask = np.isin(self.atom_type, vv["type_sel"])
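# Example: atom_type = [0, 1, 1, 0] with type_sel = [1] gives
# sel_mask = [False, True, True, False] and natoms_sel = 2.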

if mmap_obj.shape[1] == natoms_sel * ndof:
if vv["output_natoms_for_type_sel"]:
tmp = np.zeros([natoms, ndof], dtype=data.dtype)
# sel_mask needs to be applied to the original atom layout
tmp[sel_mask] = data.reshape([natoms_sel, ndof])
data = tmp
else: # output is natoms_sel
natoms = natoms_sel
idx_map = idx_map_sel
elif mmap_obj.shape[1] == natoms * ndof:
data = data.reshape([natoms, ndof])
if vv["output_natoms_for_type_sel"]:
pass
else:
data = data[sel_mask]
idx_map = idx_map_sel
natoms = natoms_sel
else: # Shape mismatch error
raise ValueError(
f"The shape of the data {key} in {set_dir} has width {mmap_obj.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})"
)

# Handle special case for Hessian
if key == "hessian":
data = data.reshape(3 * natoms, 3 * natoms)
num_chunks, chunk_size = len(idx_map), 3
idx_map_hess = np.arange(
num_chunks * chunk_size, dtype=int
).reshape(num_chunks, chunk_size)
idx_map_hess = idx_map_hess[idx_map].flatten()
data = data[idx_map_hess, :]
data = data[:, idx_map_hess]
data = data.reshape(-1)
# size of hessian is 3Natoms * 3Natoms
# ndof = 3 * ndof * 3 * ndof
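# Example: natoms = 2 with idx_map = [1, 0] expands to
# idx_map_hess = [3, 4, 5, 0, 1, 2], permuting whole 3x3 blocks of the
# Hessian.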
else:
# data should be 2D here: [natoms, ndof]
data = data.reshape([natoms, -1])
data = data[idx_map, :]
else:
data = data.reshape([ndof])

# Atomic: return [natoms, ndof] or flattened hessian above
# Non-atomic: return [ndof]
return np.float32(1.0), data

except ValueError as err_message:
explanation = (
"This error may occur when your label mismatches its name, "
"e.g., global tensor stored in `atomic_tensor.npy` or atomic tensor in `tensor.npy`."
)
log.exception(
"Single-frame load failed for key=%s, set=%s, frame=%d. %s",
key,
set_dir,
frame_idx,
explanation,
)
raise ValueError(f"{err_message}. {explanation}") from err_message

def _load_type(self, sys_path: DPPath) -> np.ndarray:
atom_type = (sys_path / "type.raw").load_txt(ndmin=1).astype(np.int32)
return atom_type
@@ -741,6 +964,34 @@ def _check_pbc(self, sys_path: DPPath) -> bool:
def _check_mode(self, set_path: DPPath) -> bool:
return (set_path / "real_atom_types.npy").is_file()

@staticmethod
@functools.lru_cache(maxsize=LRU_CACHE_SIZE)
def _create_memmap(path_str: str, mtime_str: str) -> np.memmap:
"""A cached helper function to create memmap objects.
Using lru_cache to limit the number of open file handles.

Parameters
----------
path_str
The file path as a string.
mtime_str
The modification time as a string, used for cache invalidation.
"""
with open(path_str, "rb") as f:
version = np.lib.format.read_magic(f)
if version[0] == 1:
shape, fortran_order, dtype = np.lib.format.read_array_header_1_0(f)
elif version[0] in [2, 3]:
shape, fortran_order, dtype = np.lib.format.read_array_header_2_0(f)
else:
raise ValueError(f"Unsupported .npy file version: {version}")
offset = f.tell()
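# The header ends here, so the raw array bytes begin at this offset and
# can be mapped directly without np.load.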
order = "F" if fortran_order else "C"
# Create a read-only memmap
return np.memmap(
path_str, dtype=dtype, mode="r", shape=shape, order=order, offset=offset
)


class DataRequirementItem:
"""A class to store the data requirement for data systems.