Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/megatron/energon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
MetadatasetV2,
load_dataset,
prepare_metadataset,
traverse_metadataset,
)
from megatron.energon.savable_loader import SavableDataLoader
from megatron.energon.source_info import SourceInfo
Expand Down Expand Up @@ -191,6 +192,7 @@
"SampleDecoder",
"SavableDataLoader",
"SavableDataset",
"traverse_metadataset",
"ShuffleBufferDataset",
"SimilarityInterleavedSample",
"SimilarityInterleavedWebdataset",
Expand Down
91 changes: 81 additions & 10 deletions src/megatron/energon/epathlib/epath.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import os
import re
import shutil
from pathlib import Path as PathlibPath
from pathlib import PurePosixPath
from typing import BinaryIO, Generator, Literal, Optional, TextIO, Tuple, Union, overload
from typing import (
BinaryIO,
Generator,
Literal,
Optional,
TextIO,
Tuple,
Union,
overload,
)

import multistorageclient as msc

Expand All @@ -31,7 +41,9 @@ class EPath:
for more information.
"""

# The path without the protocol. Can also be in S3 for example
__slots__ = ("internal_path", "profile", "fs")

# The path without the protocol/profile. Can also be in S3 for example
internal_path: PurePosixPath
# The profile used to access the file system
profile: str
Expand All @@ -52,11 +64,29 @@ def __init__(
profile = DEFAULT_PROFILE_NAME
else:
protocol, profile, path = self._split_protocol(initial_path)
if protocol is None or protocol == "file":
if protocol is None:
# Just a local absolute/relative path
assert profile is None
profile = DEFAULT_PROFILE_NAME
path = str(PathlibPath(path).absolute())
elif protocol == "file":
# A file:// path, e.g. file:///home/user/file.txt (absolute) or file://file.txt (relative)
assert profile is not None
path = profile + "/" + path
profile = DEFAULT_PROFILE_NAME
path = str(PathlibPath(path).absolute())
elif protocol == "rclone":
warn_deprecated("rclone:// protocol is deprecated. Use msc:// instead.")
elif protocol == "dss":
# Profile corresponds to the dataset name and version
assert profile is not None
assert NVDATASET_CACHE_DIR is not None, (
"Environment variable NVDATASET_CACHE_DIR is not set"
)
self.fs = NVDATASET_CACHE_DIR.fs
self.profile = "dss"
self.internal_path = self._resolve(f"/{profile}/{path}")
return
else:
assert protocol == "msc", f"Unknown protocol: {protocol}"
if not path.startswith("/"):
Expand All @@ -77,7 +107,13 @@ def __getstate__(self) -> dict:
def __setstate__(self, state: dict) -> None:
    """Restore a pickled EPath.

    The file-system handle (`self.fs`) is deliberately not pickled; it is
    re-resolved here from the restored profile.

    Args:
        state: Mapping containing the pickled "internal_path" and "profile".
    """
    self.internal_path = state["internal_path"]
    self.profile = state["profile"]
    if self.profile == "dss":
        # DSS paths share the file system of the configured cache directory.
        assert NVDATASET_CACHE_DIR is not None, (
            "Environment variable NVDATASET_CACHE_DIR is not set"
        )
        self.fs = NVDATASET_CACHE_DIR.fs
    else:
        # All other profiles resolve through multistorageclient.
        self.fs, _ = msc.resolve_storage_client(f"msc://{self.profile}")

@staticmethod
def _resolve(path: Union[str, PurePosixPath]) -> PurePosixPath:
Expand All @@ -103,16 +139,27 @@ def _resolve(path: Union[str, PurePosixPath]) -> PurePosixPath:

@staticmethod
def _split_protocol(path: str) -> Tuple[Optional[str], Optional[str], str]:
regex = re.compile(r"^(?P<protocol>[a-z]+)://(?P<profile>[^/]+?)/(?P<path>.+)$")
regex = re.compile(r"^(?P<protocol>[a-z]+)://(?P<profile>[^/]+?)(?:/(?P<path>.*))?$")
m = regex.match(path)
if m is None:
return None, None, path
return m.group("protocol"), m.group("profile"), m.group("path")
inner_path = m.group("path")
if not inner_path:
inner_path = ""
return m.group("protocol"), m.group("profile"), inner_path

@property
def _internal_str_path(self) -> str:
    """Return the path as used inside the file system, without the protocol/profile part.

    This is the form expected by `self.fs` functions. For "dss" paths the
    internal path is relative to NVDATASET_CACHE_DIR, so it is prefixed with
    the cache directory's own internal path.
    """
    if self.profile == "dss":
        assert NVDATASET_CACHE_DIR is not None, (
            "Environment variable NVDATASET_CACHE_DIR is not set"
        )
        # internal_path is absolute (starts with "/"), so plain concatenation
        # joins the cache dir and the dss-relative path with a separator.
        return NVDATASET_CACHE_DIR._internal_str_path + str(self.internal_path)
    else:
        return str(self.internal_path)

@overload
def open(
Expand Down Expand Up @@ -189,13 +236,22 @@ def parent(self) -> "EPath":

@property
def url(self) -> str:
    """Return the canonical string form of this path, including its protocol.

    Local (default-profile) paths are returned as plain filesystem paths,
    "dss" paths as `dss://<name>/...`, and all other profiles as
    `msc://<profile>/...`.
    """
    if self.profile == DEFAULT_PROFILE_NAME:
        return self._internal_str_path
    int_path_str = str(self.internal_path)
    if self.profile == "dss":
        # dss URLs carry no leading slash after the scheme.
        if int_path_str.startswith("/"):
            int_path_str = int_path_str[1:]
        return f"dss://{int_path_str}"
    return f"msc://{self.profile}{int_path_str}"

def is_local(self) -> bool:
    """Return True if this path is treated as a local filesystem path.

    "dss" paths are currently always considered local; note that this does
    not imply the data already exists on the local filesystem.
    """
    if self.profile == "dss":
        return True
    else:
        return self.profile == DEFAULT_PROFILE_NAME

def local_path(self) -> PathlibPath:
if not self.is_local():
Expand Down Expand Up @@ -245,6 +301,15 @@ def relative_to(self, other: "EPath") -> str:

return str(self.internal_path.relative_to(other.internal_path))

@property
def display_name(self) -> str:
    """A short human-readable name for this path.

    For "dss" paths this is the dataset name component (e.g. "charts@v0"
    from "/charts@v0/something"); for all other paths it is the plain name.
    """
    if self.profile != "dss":
        return self.name
    # internal_path is absolute, so parts[0] is "/" and parts[1] is the
    # dataset name (including its version tag).
    return self.internal_path.parts[1]

def __truediv__(self, other: Union[str, "EPath"]) -> "EPath":
if isinstance(other, EPath):
# Always absolute
Expand Down Expand Up @@ -276,3 +341,9 @@ def __eq__(self, other: object) -> bool:
and self.internal_path == other.internal_path
and self.profile == other.profile
)


# Optional cache root for "dss" paths, configured via the environment.
# Remains None when NVDATASET_CACHE_DIR is unset; code paths that need it
# assert on its presence.
NVDATASET_CACHE_DIR = (
    EPath(os.environ["NVDATASET_CACHE_DIR"])
    if "NVDATASET_CACHE_DIR" in os.environ
    else None
)
1 change: 1 addition & 0 deletions src/megatron/energon/flavors/jsonl/crude_jsonl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__"
self.path = path
self.paths = [path]
self.name = path.display_name
self.training = training
self.worker_config = worker_config
self.shuffle_over_epochs = shuffle_over_epochs
Expand Down
1 change: 1 addition & 0 deletions src/megatron/energon/flavors/webdataset/base_webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
)
self.path = path
self.paths = [path]
self.name = path.display_name
self.shards = wds_meta.shards
self.sample_excludes = wds_meta.sample_excludes
self.split_part_files = wds_meta.split_part_files
Expand Down
7 changes: 6 additions & 1 deletion src/megatron/energon/metadataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
# SPDX-License-Identifier: BSD-3-Clause

from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.loader import load_dataset, prepare_metadataset
from megatron.energon.metadataset.loader import (
load_dataset,
prepare_metadataset,
traverse_metadataset,
)
from megatron.energon.metadataset.loader_interface import DatasetLoaderInterface
from megatron.energon.metadataset.metadataset import Metadataset
from megatron.energon.metadataset.metadataset_v2 import MetadatasetV2
Expand All @@ -14,4 +18,5 @@
"Metadataset",
"MetadatasetV2",
"prepare_metadataset",
"traverse_metadataset",
]
45 changes: 44 additions & 1 deletion src/megatron/energon/metadataset/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.loader_interface import DatasetLoaderInterface
from megatron.energon.metadataset.loader_interface import (
DatasetLoaderInterface,
TraversedDatasetReference,
)
from megatron.energon.metadataset.metadataset import Metadataset
from megatron.energon.typed_converter import JsonParser

Expand Down Expand Up @@ -46,6 +49,46 @@ def load_dataset(
raise ValueError(f"Invalid dataset at {path}")


def traverse_metadataset(
    path: Union[str, EPath, Path],
    *,
    split_part: str,
    **kwargs,
) -> list[TraversedDatasetReference]:
    """Resolve a metadataset and return its flattened leaf dataset references.

    Public entrypoint for traversal-only inspection of a metadataset: loads
    the root metadataset configuration, recursively resolves nested
    metadatasets, and returns the final leaf references for one split
    without building the intermediate scanned/traversed loader tree.

    Args:
        path: Path to the metadataset YAML file to traverse.
        split_part: Which split to traverse, e.g. "train", "val", or "test".
        **kwargs: Extra keyword arguments forwarded to `load_config()` when
            loading the root metadataset object.

    Returns:
        Flattened list of `TraversedDatasetReference` values describing the
        reachable leaf datasets for the requested split.

    Raises:
        AssertionError: If `path` does not point to a metadataset.
    """
    path = EPath(path)
    ds_type = get_dataset_type(path)
    assert ds_type == EnergonDatasetType.METADATASET, (
        f"traverse_metadataset only supports metadatasets, got {ds_type} at {path}"
    )
    return load_config(
        path,
        default_type=Metadataset,
        default_kwargs=dict(path=path, **kwargs),
    ).traverse(split_part=split_part)


class MockJsonParser(JsonParser):
"""Json Parser, which translates unknown objects to a mock class."""

Expand Down
44 changes: 44 additions & 0 deletions src/megatron/energon/metadataset/loader_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

Expand Down Expand Up @@ -38,6 +39,23 @@ class LoadedDatasetList:
blend_mode: DatasetBlendMode = DatasetBlendMode.NONE


@dataclass
class TraversedDatasetReference:
    """Flattened leaf dataset reference produced by metadataset traversal."""

    # Resolved path to the referenced leaf dataset.
    path: EPath
    # Effective split part to use when loading the leaf dataset.
    split_part: str
    # Resolved auxiliary dataset/filesystem references, keyed by auxiliary name.
    aux: dict[str, EPath]
    # Effective subflavors implied by the traversed metadataset hierarchy.
    subflavors: dict[str, Any]


class DatasetLoaderInterface(ABC):
"""General interface for a dataset loader."""

Expand All @@ -46,6 +64,32 @@ def post_initialize(self, mds_path: Optional[EPath] = None):
"""Called to finally initialize the dataset."""
...

def traverse(
    self,
    mds_path: Optional[EPath] = None,
    *,
    split_part: Union[Literal["train", "val", "test"], str],
    _subflavors: Optional[Dict[str, Any]] = None,
) -> List[TraversedDatasetReference]:
    """Traverse a metadataset subtree and collect flattened leaf dataset references.

    Traversal-side counterpart to `get_datasets()`: instead of instantiating
    dataset loaders for leaf datasets, it walks the hierarchy, resolves nested
    metadataset references, and returns the final leaf dataset references for
    a single split.

    Args:
        mds_path: Parent metadataset path used internally to resolve relative
            dataset and auxiliary paths. Must be set for nested references and
            inner traversal nodes; use None only for top-level metadatasets.
        split_part: Split to traverse, such as "train", "val", or "test".
            Nested references may override this with their own configured split.
        _subflavors: Accumulated subflavors from the enclosing hierarchy,
            passed down internally during recursion — presumably merged by
            implementations; confirm against concrete subclasses.

    Returns:
        A flattened list of `TraversedDatasetReference` values for all leaf
        datasets reached during the traversal.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError(f"{type(self).__name__} does not implement traverse()")

@abstractmethod
def get_datasets(
self,
Expand Down
Loading
Loading