12 changes: 12 additions & 0 deletions fast_llm/data/auto.py
@@ -2,4 +2,16 @@
Import these submodules to ensure classes are added to the dynamic class registry.
"""

from fast_llm.data.dataset.config import ( # isort: skip
BlendedDatasetConfig,
ConcatenatedDatasetConfig,
DatasetSliceConfig,
MemmapDatasetConfig,
SampledDatasetUpdateConfig,
)
from fast_llm.data.dataset.gpt.config import ( # isort: skip
GPTDatasetFromFileConfig,
GPTFimSampledDatasetConfig,
GPTRandomDatasetConfig,
)
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip
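
Note: these imports exist only for their side effect. Each `@config_class(dynamic_type=...)` decorator registers its class under a string key, so `auto.py` must import the submodules for the registry to be populated. A minimal sketch of the pattern, with illustrative names rather than Fast-LLM's actual implementation:

```python
# Minimal sketch of a dynamic-type registry populated by a class decorator;
# names are illustrative, not Fast-LLM's real implementation.
_REGISTRY: dict[str, type] = {}

def config_class(dynamic_type: str | None = None):
    def wrapper(cls: type) -> type:
        if dynamic_type is not None:
            # Importing the module runs this registration as a side effect,
            # which is why auto.py imports submodules it never uses directly.
            _REGISTRY[dynamic_type] = cls
        return cls
    return wrapper

@config_class(dynamic_type="memmap")
class MemmapDatasetConfig:
    pass

assert _REGISTRY["memmap"] is MemmapDatasetConfig
```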
41 changes: 0 additions & 41 deletions fast_llm/data/config.py
@@ -1,12 +1,4 @@
import enum
import pathlib
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.tokenizer import Tokenizer


class MultiprocessingContext(str, enum.Enum):
@@ -15,36 +7,3 @@ class MultiprocessingContext(str, enum.Enum):
fork = "fork"
# Safe but much slower.
spawn = "spawn"


TokenizerFromFile = "TokenizerFromFile"


@config_class()
class TokenizerConfig(Config):
"""
Configuration for the tokenizer.
The tokenizer is needed for FIM and dataset preparation.
"""

format: str = Field(
default="TokenizerFromFile",
desc="Unused.",
hint=FieldHint.deprecated,
valid=check_field(Assert.eq, TokenizerFromFile),
)
path: pathlib.Path = Field(
default=None,
desc="Path to the tokenizer file.",
hint=FieldHint.core,
)
bos_token: str | None = Field(
default=None,
desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.",
hint=FieldHint.core,
)

def get_tokenizer(self) -> "Tokenizer":
from fast_llm.data.tokenizer import Tokenizer

return Tokenizer(self)
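
Note: `TokenizerConfig` is deleted here rather than changed; the import added in `fast_llm/data/dataset/gpt/config.py` below shows it now lives in `fast_llm.data.preprocessing.tokenizer`. Assuming the relocated class keeps the same fields and `get_tokenizer()` helper, call sites only need the import path updated:

```python
# Before (removed above):
# from fast_llm.data.config import TokenizerConfig

# After, matching the import added in gpt/config.py below:
from fast_llm.data.preprocessing.tokenizer import TokenizerConfig

# Assumption: the move does not change the fields or helpers, e.g.:
# tokenizer = TokenizerConfig(path=pathlib.Path("tokenizer.json")).get_tokenizer()
```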
33 changes: 32 additions & 1 deletion fast_llm/data/dataset/config.py
@@ -2,6 +2,7 @@
import enum
import functools
import itertools
import logging
import math
import pathlib
import typing
@@ -13,8 +14,11 @@

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.engine.distributed.distributed import Distributed

logger = logging.getLogger(__name__)


class ShufflingType(str, enum.Enum):
# Shuffle all epochs together. Not extendable.
@@ -105,7 +109,6 @@ class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]):
"""

def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
# TODO: ====== `SamplingData` contains more than needed (ex. `num_samples`)
raise NotImplementedError()


@@ -266,3 +269,31 @@ def build_and_sample(
self.weights,
sampling,
)


@config_class(dynamic_type={SampledDatasetConfig: "memmap"})
class MemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]):
_abstract: typing.ClassVar[bool] = False
path: pathlib.Path = Field(
default=None,
desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.",
hint=FieldHint.core,
)

def build(self) -> "IndexedDataset[SampleType]":
name = str(self.path).replace("/", "__")
if self.path.is_file():
from fast_llm.data.dataset.memmap import MemmapDataset

return MemmapDataset[SampleType](name, self.path)
elif self.path.with_suffix(".bin").is_file() and self.path.with_suffix(".idx").is_file():
logger.warning(
"Using the legacy memmap dataset format."
" This format is deprecated and will be removed in a future release."
" Please recreate the dataset in the new memmap format."
)
from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset

return LegacyMemmapDataset[SampleType](name, self.path)
else:
raise FileNotFoundError(self.path)
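
Note: `build()` dispatches on what exists at `path`: a single file is the new memmap format, a `.bin`/`.idx` pair is the deprecated legacy format (loaded with a warning), and anything else raises. A hedged usage sketch; the path is illustrative, and the keyword-argument construction assumes the usual `config_class` behavior:

```python
import pathlib

from fast_llm.data.dataset.config import MemmapDatasetConfig

config = MemmapDatasetConfig(path=pathlib.Path("/data/corpus/shard_0"))

# /data/corpus/shard_0 is a file          -> MemmapDataset (new format)
# shard_0.bin and shard_0.idx both exist  -> LegacyMemmapDataset, with a
#                                            deprecation warning
# neither                                 -> FileNotFoundError
dataset = config.build()
```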
44 changes: 6 additions & 38 deletions fast_llm/data/dataset/gpt/config.py
@@ -6,22 +6,15 @@
import yaml

from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.data.config import TokenizerConfig
from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
from fast_llm.data.dataset.config import (
IndexedDatasetConfig,
SamplableDatasetConfig,
SampledDatasetConfig,
SamplingData,
SamplingParameters,
)
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters
from fast_llm.data.preprocessing.tokenizer import TokenizerConfig
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.gpt.fim import GPTFimDataset
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset
from fast_llm.data.sample.language_model import LanguageModelSample


@dataclasses.dataclass(kw_only=True)
@@ -30,7 +23,9 @@ class GPTSamplingParameters(SamplingParameters):
Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model.
"""

vocab_size: int
# TODO: Only used for random dataset. Remove? Or use as safety check?
vocab_size: int | None = None
# TODO: ====== Get these to memmap dataset (currently ignored) ======
use_loss_masking_spans: bool = False
use_preference_loss_spans: bool = False

@@ -60,33 +55,6 @@ def build(self) -> "GPTRandomDataset[SampleType]":
return GPTRandomDataset[SampleType](self.name)


@config_class(dynamic_type={SampledDatasetConfig: "memmap"})
class GPTMemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]):
_abstract: typing.ClassVar[bool] = False
path: pathlib.Path = Field(
default=None,
desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.",
hint=FieldHint.core,
)
num_documents: int | None = Field(
default=None,
desc="Expected number of documents in the dataset.",
hint=FieldHint.optional,
)
num_tokens: int | None = Field(
default=None,
desc="Expected number of tokens in the dataset.",
hint=FieldHint.optional,
)

def build(self) -> "GPTMemmapDataset[SampleType]":
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

return GPTMemmapDataset[SampleType](
str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens
)


@config_class(dynamic_type={SampledDatasetConfig: "file"})
class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]):
_abstract: typing.ClassVar[bool] = False
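
Note: `vocab_size` on `GPTSamplingParameters` changes from required to `int | None = None`; per the TODO it is now only consumed by the random dataset. A self-contained sketch of the new shape (only the three fields visible in this diff are shown; the real dataclass inherits further fields from `SamplingParameters`):

```python
import dataclasses

@dataclasses.dataclass(kw_only=True)
class GPTSamplingParametersSketch:
    vocab_size: int | None = None  # now only needed by the random dataset
    use_loss_masking_spans: bool = False
    use_preference_loss_spans: bool = False

params_memmap = GPTSamplingParametersSketch()                  # no vocab_size required
params_random = GPTSamplingParametersSketch(vocab_size=32000)  # random dataset still sets one
```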
8 changes: 5 additions & 3 deletions fast_llm/data/dataset/gpt/fim.py
@@ -5,6 +5,7 @@
from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.sample.token import TokenSample
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.distributed.config import MAX_SEED


@@ -168,9 +169,10 @@ def _fim_permute_sequence(
middle = contents[boundaries[0] : boundaries[1]]
suffix = contents[boundaries[1] :]

prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=sequence.dtype)
middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=sequence.dtype)
suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=sequence.dtype)
data_type = DataType.from_numpy(sequence.dtype)
prefix = self._tokenizer.tokenize(prefix, end=False, data_type=data_type).numpy()
middle = self._tokenizer.tokenize(middle, begin=False, end=False, data_type=data_type).numpy()
suffix = self._tokenizer.tokenize(suffix, begin=False, data_type=data_type).numpy()

# here we truncate each given segment to fit the same length as it was before
# A consequence is that we never reach the end of a file?
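
Note: the old code re-wrapped the tokenizer output with `np.array(..., dtype=sequence.dtype)`; the new code asks `tokenize()` for the right `DataType` up front and converts the returned tensor with `.numpy()`. A self-contained sketch of that dtype round-trip, with a stand-in tensor in place of the tokenizer output:

```python
import numpy as np
import torch

sequence = np.zeros(16, dtype=np.int32)                 # token ids being permuted
tokens = torch.tensor([101, 7, 42], dtype=torch.int32)  # stand-in for tokenize(..., data_type=...) output
prefix = tokens.numpy()                                 # shares memory with the tensor, no re-wrap needed
assert prefix.dtype == sequence.dtype                   # dtype preserved end to end
```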