10 changes: 6 additions & 4 deletions deepmd/pd/utils/dataloader.py
@@ -13,7 +13,6 @@
Thread,
)

import h5py
import numpy as np
import paddle
import paddle.distributed as dist
@@ -90,9 +89,12 @@ def __init__(
):
if seed is not None:
setup_seed(seed)
if isinstance(systems, str):
with h5py.File(systems) as file:
systems = [os.path.join(systems, item) for item in file.keys()]
# Use process_systems to handle HDF5 expansion and other system processing
from deepmd.utils.data_system import (
process_systems,
)

systems = process_systems(systems)

self.systems: list[DeepmdDataSetForLoader] = []
if len(systems) >= 100:
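The Paddle loader no longer opens HDF5 files itself; the raw `systems` argument is handed to the shared `process_systems` helper (diffed below in `deepmd/utils/data_system.py`). A minimal sketch of the delegated call, using a hypothetical file name:

```python
from deepmd.utils.data_system import process_systems

# An HDF5 entry in the list is expanded to one path per internal system group,
# using the "#" separator, e.g. ["water_multi.h5#/sys0", "water_multi.h5#/sys1"].
systems = process_systems(["water_multi.h5"])
```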
11 changes: 6 additions & 5 deletions deepmd/pt/utils/dataloader.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
from multiprocessing.dummy import (
Pool,
)
@@ -10,7 +9,6 @@
Union,
)

import h5py
import numpy as np
import torch
import torch.distributed as dist
@@ -88,9 +86,12 @@ def __init__(
) -> None:
if seed is not None:
setup_seed(seed)
if isinstance(systems, str):
with h5py.File(systems) as file:
systems = [os.path.join(systems, item) for item in file.keys()]
# Use process_systems to handle HDF5 expansion and other system processing
from deepmd.utils.data_system import (
process_systems,
)

systems = process_systems(systems)

def construct_dataset(system: str) -> DeepmdDataSetForLoader:
return DeepmdDataSetForLoader(
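The PyTorch loader receives the identical replacement. With the updated helper (see the `data_system.py` diff below), the `systems` argument can stay a plain directory string, a list of system directories, or a list that mixes directories with HDF5 files. A rough sketch of the accepted shapes, with made-up paths:

```python
from deepmd.utils.data_system import process_systems

flat = process_systems("data/train")                   # directory: searched for systems
flat = process_systems(["data/sys0", "data/sys1"])     # explicit list: passed through
flat = process_systems(["data/sys0", "all_water.h5"])  # HDF5 entries expanded to "file.h5#/group"
```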
144 changes: 143 additions & 1 deletion deepmd/utils/data_system.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import collections
import itertools
import logging
import os
import warnings
from functools import (
cached_property,
@@ -11,6 +13,7 @@
Union,
)

import h5py
import numpy as np

import deepmd.utils.random as dp_random
@@ -784,12 +787,63 @@ def prob_sys_size_ext(keywords: str, nsystems: int, nbatch: int) -> list[float]:
return sys_probs


def _process_single_system(system: str) -> list[str]:
"""Process a single system string and return list of systems.

Parameters
----------
system : str
A single system path

Returns
-------
list[str]
List of processed system paths
"""
# Check if this is an HDF5 file without explicit system specification
if _is_hdf5_file(system) and "#" not in system:
try:
with h5py.File(system, "r") as file:
# Check if this looks like a single system (has type.raw and set.* groups at root)
has_type_raw = "type.raw" in file
has_sets = any(key.startswith("set.") for key in file.keys())

if has_type_raw and has_sets:
# This is a single system HDF5 file, use standard HDF5 format
return [f"{system}#/"]

# Look for system-like groups and expand them
expanded = []
for key in file.keys():
if isinstance(file[key], h5py.Group):
# Check if this group looks like a system
group = file[key]
group_has_type = "type.raw" in group
group_has_sets = any(
subkey.startswith("set.") for subkey in group.keys()
)
if group_has_type and group_has_sets:
expanded.append(f"{system}#/{key}")

# If we found system-like groups, return them; otherwise treat as regular system
return expanded if expanded else [system]

except OSError as e:
log.warning(f"Could not read HDF5 file {system}: {e}")
# If we can't read as HDF5, treat as regular system
return [system]
else:
# Regular system or HDF5 with explicit system specification
return [system]


def process_systems(
systems: Union[str, list[str]], patterns: Optional[list[str]] = None
) -> list[str]:
"""Process the user-input systems.

If it is a single directory, search for all the systems in the directory.
If it is a list, expand any HDF5 entries into their internal systems.
Check if the systems are valid.

Parameters
@@ -810,10 +864,98 @@ def process_systems(
else:
systems = rglob_sys_str(systems, patterns)
elif isinstance(systems, list):
systems = systems.copy()
# Process each system individually and flatten results
systems = list(
itertools.chain.from_iterable(
_process_single_system(system) for system in systems
)
)
return systems


def _is_hdf5_file(path: str) -> bool:
"""Check if a path points to an HDF5 file.

Parameters
----------
path : str
Path to check

Returns
-------
bool
True if the path is an HDF5 file
"""
# Extract the actual file path (before any # separator for HDF5 internal paths)
file_path = path.split("#")[0]
return os.path.isfile(file_path) and (
file_path.endswith((".h5", ".hdf5")) or _is_hdf5_format(file_path)
)


def _is_hdf5_multisystem(file_path: str) -> bool:
"""Check if an HDF5 file contains multiple systems vs being a single system.

Parameters
----------
file_path : str
Path to the HDF5 file

Returns
-------
bool
True if the file contains multiple systems, False if it's a single system
"""
try:
with h5py.File(file_path, "r") as f:
# Check if this looks like a single system (has type.raw and set.* groups)
has_type_raw = "type.raw" in f
has_sets = any(key.startswith("set.") for key in f.keys())

if has_type_raw and has_sets:
# This looks like a single system
return False

# Check if it contains multiple groups that could be systems
system_groups = []
for key in f.keys():
if isinstance(f[key], h5py.Group):
group = f[key]
# Check if this group looks like a system (has type.raw and sets)
group_has_type = "type.raw" in group
group_has_sets = any(
subkey.startswith("set.") for subkey in group.keys()
)
if group_has_type and group_has_sets:
system_groups.append(key)

# If we found multiple system-like groups, it's a multisystem file
return len(system_groups) > 1

except OSError:
return False


def _is_hdf5_format(file_path: str) -> bool:
"""Check if a file is in HDF5 format by trying to open it.

Parameters
----------
file_path : str
Path to the file

Returns
-------
bool
True if the file is in HDF5 format
"""
try:
with h5py.File(file_path, "r"):
return True
except OSError:
return False


def get_data(
jdata: dict[str, Any],
rcut: float,
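To make the expansion rules concrete, here is a self-contained sketch against the helpers added above. It builds a throwaway HDF5 file whose groups carry the two markers `_process_single_system` looks for (a `type.raw` entry plus at least one `set.*` group); the file and group names are invented for the illustration:

```python
import h5py

from deepmd.utils.data_system import process_systems

# Toy multi-system file: two groups, each shaped like a system.
with h5py.File("toy_multi.h5", "w") as f:
    for name in ("sys_a", "sys_b"):
        grp = f.create_group(name)
        grp.create_dataset("type.raw", data=[0, 0, 1])  # marker 1: atom types
        grp.create_group("set.000")                     # marker 2: a data set

print(process_systems(["toy_multi.h5"]))
# -> ['toy_multi.h5#/sys_a', 'toy_multi.h5#/sys_b']

# A path with an explicit "#" selector is passed through untouched, and a file
# with type.raw / set.* at its root is treated as a single system ("file.h5#/").
print(process_systems(["toy_multi.h5#/sys_a"]))
# -> ['toy_multi.h5#/sys_a']
```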