diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py index c44e538fa..22ab3d731 100644 --- a/fast_llm/data/auto.py +++ b/fast_llm/data/auto.py @@ -2,4 +2,16 @@ Import these submodules to ensure classes are added to the dynamic class registry. """ +from fast_llm.data.dataset.config import ( # isort: skip + BlendedDatasetConfig, + ConcatenatedDatasetConfig, + DatasetSliceConfig, + MemmapDatasetConfig, + SampledDatasetUpdateConfig, +) +from fast_llm.data.dataset.gpt.config import ( # isort: skip + GPTDatasetFromFileConfig, + GPTFimSampledDatasetConfig, + GPTRandomDatasetConfig, +) from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 633367c80..78bc20636 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,12 +1,4 @@ import enum -import pathlib -import typing - -from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - from fast_llm.data.tokenizer import Tokenizer class MultiprocessingContext(str, enum.Enum): @@ -15,36 +7,3 @@ class MultiprocessingContext(str, enum.Enum): fork = "fork" # Safe but much slower. spawn = "spawn" - - -TokenizerFromFile = "TokenizerFromFile" - - -@config_class() -class TokenizerConfig(Config): - """ - Configuration for the tokenizer. - The tokenizer is needed for FIM and dataset preparation. - """ - - format: str = Field( - default="TokenizerFromFile", - desc="Unused.", - hint=FieldHint.deprecated, - valid=check_field(Assert.eq, TokenizerFromFile), - ) - path: pathlib.Path = Field( - default=None, - desc="Path to the tokenizer file.", - hint=FieldHint.core, - ) - bos_token: str | None = Field( - default=None, - desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", - hint=FieldHint.core, - ) - - def get_tokenizer(self) -> "Tokenizer": - from fast_llm.data.tokenizer import Tokenizer - - return Tokenizer(self) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 20e40b66e..7611b4a31 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -2,6 +2,7 @@ import enum import functools import itertools +import logging import math import pathlib import typing @@ -13,8 +14,11 @@ if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset + from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class ShufflingType(str, enum.Enum): # Shuffle all epochs together. Not extendable. @@ -105,7 +109,6 @@ class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: - # TODO: ====== `SamplingData` contains more than needed (ex. 
`num_samples`) raise NotImplementedError() @@ -266,3 +269,31 @@ def build_and_sample( self.weights, sampling, ) + + +@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) +class MemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): + _abstract: typing.ClassVar[bool] = False + path: pathlib.Path = Field( + default=None, + desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", + hint=FieldHint.core, + ) + + def build(self) -> "IndexedDataset[SampleType]": + name = str(self.path).replace("/", "__") + if self.path.is_file(): + from fast_llm.data.dataset.memmap import MemmapDataset + + return MemmapDataset[SampleType](name, self.path) + elif self.path.with_suffix(".bin").is_file() and self.path.with_suffix(".idx").is_file(): + logger.warning( + "Using the legacy memmap dataset format." + " This format is deprecated and will be removed in a future release." + " Please recreate the dataset in the new memmap format." + ) + from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset + + return LegacyMemmapDataset[SampleType](name, self.path) + else: + raise FileNotFoundError(self.path) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 15f54ec80..8dd4098a3 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -6,22 +6,15 @@ import yaml from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset -from fast_llm.data.dataset.config import ( - IndexedDatasetConfig, - SamplableDatasetConfig, - SampledDatasetConfig, - SamplingData, - SamplingParameters, -) -from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset - from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset + from fast_llm.data.sample.language_model import LanguageModelSample @dataclasses.dataclass(kw_only=True) @@ -30,7 +23,9 @@ class GPTSamplingParameters(SamplingParameters): Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ - vocab_size: int + # TODO: Only used for random dataset. Remove? Or use as safety check? 
+ vocab_size: int | None = None + # TODO: ====== Get these to memmap dataset (currently ignored) ====== use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False @@ -60,33 +55,6 @@ def build(self) -> "GPTRandomDataset[SampleType]": return GPTRandomDataset[SampleType](self.name) -@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) -class GPTMemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): - _abstract: typing.ClassVar[bool] = False - path: pathlib.Path = Field( - default=None, - desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", - hint=FieldHint.core, - ) - num_documents: int | None = Field( - default=None, - desc="Expected number of documents in the dataset.", - hint=FieldHint.optional, - ) - num_tokens: int | None = Field( - default=None, - desc="Expected number of tokens in the dataset.", - hint=FieldHint.optional, - ) - - def build(self) -> "GPTMemmapDataset[SampleType]": - from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - - return GPTMemmapDataset[SampleType]( - str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens - ) - - @config_class(dynamic_type={SampledDatasetConfig: "file"}) class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 1fde74530..d36384ee5 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -5,6 +5,7 @@ from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import MAX_SEED @@ -168,9 +169,10 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=sequence.dtype) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=sequence.dtype) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=sequence.dtype) + data_type = DataType.from_numpy(sequence.dtype) + prefix = self._tokenizer.tokenize(prefix, end=False, data_type=data_type).numpy() + middle = self._tokenizer.tokenize(middle, begin=False, end=False, data_type=data_type).numpy() + suffix = self._tokenizer.tokenize(suffix, begin=False, data_type=data_type).numpy() # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? 
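Illustration (not part of the diff): the `fim.py` hunk above relies on the new `Tokenizer.tokenize` signature, which returns a `torch.Tensor` in a caller-selected `DataType` instead of a plain token list. A minimal usage sketch, assuming a placeholder tokenizer path and using only names that appear in this diff:

```python
import numpy as np

from fast_llm.data.preprocessing.tokenizer import TokenizerConfig
from fast_llm.engine.config_utils.data_type import DataType

# Placeholder path; any Hugging Face tokenizer file would do here.
tokenizer = TokenizerConfig.from_dict({"path": "/path/to/tokenizer"}).get_tokenizer()

sequence = np.array([17, 23, 5], dtype=np.int32)  # an existing token buffer
data_type = DataType.from_numpy(sequence.dtype)

# `tokenize` now returns a tensor in the requested dtype, so FIM segments keep
# the same dtype as the sequence they are spliced back into.
prefix = tokenizer.tokenize("def f():", end=False, data_type=data_type).numpy()
assert prefix.dtype == sequence.dtype
```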
diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py similarity index 61% rename from fast_llm/data/dataset/gpt/memmap.py rename to fast_llm/data/dataset/gpt/legacy_memmap.py index 06d8d7acc..2a23e378b 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -1,21 +1,31 @@ import pathlib import struct -import typing import numpy as np import torch from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div +MEMMAP_DTYPES = { + 1: DataType.uint8, + 2: DataType.int8, + 3: DataType.int16, + 4: DataType.int32, + 5: DataType.int64, + 6: DataType.float32, + 7: DataType.float64, + 8: DataType.uint16, +} +MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" -class GPTMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): + +class LegacyMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. a pair of numpy file containing @@ -28,12 +38,10 @@ def __init__( self, name: str, prefix: pathlib.Path | str, - num_documents: int | None = None, - num_tokens: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init(self, name: str, prefix: pathlib.Path | str) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) @@ -54,9 +62,6 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None _ = struct.unpack(" tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + def __getstate__(self) -> tuple[str, pathlib.Path]: + return (self._name, self._prefix) - def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): + def __setstate__(self, state: tuple[str, pathlib.Path]): self._init(*state) def __del__(self): @@ -168,7 +171,7 @@ def get_document( token_ids = token_ids.to(torch.int64) if parameters is not None and parameters.use_loss_masking_spans: assert self._spans is not None - # TODO: ====== Store in range format (begin, end) ====== + # Convert to in range format (begin, end). sample_spans = RangeSample( [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size ).crop(begin, end) @@ -182,7 +185,7 @@ def get_document( raise ValueError("Failed to read chosen spans from memmap dataset.") elif self._has_preference_spans and self._rejected_spans is None: raise ValueError("Failed to read rejected spans from memmap dataset.") - # TODO: ====== Store in range format ====== + # Convert to in range format (begin, end). 
chosen_spans = RangeSample( [(self._chosen_spans[index][0].item(), self._chosen_spans[index][1].item() + 1)], sample_size, @@ -222,95 +225,3 @@ def get_document_sizes(self) -> torch.Tensor: def get_document_size(self, index: int) -> int: return self._document_sizes[index].item() - - @classmethod - def write_dataset( - cls, - prefix: pathlib.Path | str, - documents: typing.Iterable[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]], - ) -> None: - # Initialize metadata - dtype = None - num_documents = 0 - lengths = [] - pointers = [] - offset = 0 - # number of spans for each document - num_spans = [] - spans = [] - chosen_spans = [] - rejected_spans = [] - - prefix = pathlib.Path(prefix) - prefix.parent.mkdir(parents=True, exist_ok=True) - - # Write the binary data file (.bin) lazily - with prefix.with_suffix(".bin").open("wb") as bin_stream: - for token_ids, loss_masking_spans, chosen_span, rejected_span in documents: - # Infer dtype from the first document - if dtype is None: - dtype = token_ids.dtype - assert dtype is not None, "Document dtype could not be inferred from the data." - - # Ensure all documents have the same dtype - assert token_ids.dtype == dtype, f"Expected dtype {dtype}, got {token_ids.dtype}." - - # Write document to binary file - bin_stream.write(token_ids.numpy().tobytes(order="C")) - - # Update metadata - doc_length = len(token_ids) - lengths.append(doc_length) - pointers.append(offset) - if loss_masking_spans is not None: - num_spans.append(len(loss_masking_spans)) - spans.append(loss_masking_spans) - if chosen_span is not None: - chosen_spans.append(chosen_span) - if rejected_span is not None: - rejected_spans.append(rejected_span) - offset += doc_length * dtype.itemsize - num_documents += 1 - - # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) - pointers = np.array(pointers, dtype=np.int64) - num_spans = np.array(num_spans, dtype=np.int32) - if len(spans) > 0: - spans = np.vstack(spans, dtype=np.int32) - else: - spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) - - # Write the index file (.idx) - with prefix.with_suffix(".idx").open("wb") as idx_stream: - idx_stream.write(MEMMAP_INDEX_HEADER) - # Indicates the version - # Version 2 optionally adds loss-masking spans - # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) - # Flag to indicate whether preference loss-masking spans are present - idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) - # Data type - idx_stream.write(struct.pack(" SampleType: pass - @abc.abstractmethod def __len__(self) -> int: """ - Number of samples in the dataset. + Number of documents in the dataset. + Note: this default implementation is slow and should be overridden when possible. + """ + return len(self.get_document_sizes()) + + @property + def num_tokens(self) -> int: + """ + Number of tokens in the dataset. + Note: this default implementation is slow and should be overridden when possible. """ + return self.get_document_sizes().sum().item() def sample(self, sampling: SamplingData) -> "GPTSampledIndexedDataset": from fast_llm.data.dataset.sampled import SampledIndexedDataset @@ -108,6 +117,13 @@ def __init__( def __len__(self) -> int: return self._dataset_splits[-1].item() + @property + def num_tokens(self) -> int: + """ + Number of tokens in the dataset. 
+        """
+        return sum(dataset.num_tokens for dataset in self._datasets)
+
     def get_document_sizes(self) -> torch.Tensor:
         # TODO: This can be really big.
         return torch.cat([dataset.get_document_sizes() for dataset in self._datasets])
diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py
new file mode 100644
index 000000000..4b1930dd3
--- /dev/null
+++ b/fast_llm/data/dataset/memmap.py
@@ -0,0 +1,105 @@
+import json
+import pathlib
+import typing
+
+import numpy as np
+import torch
+
+from fast_llm.data.dataset.config import SamplingParameters
+from fast_llm.data.dataset.indexed import IndexedDataset
+from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig, MemmapWriter, Sample
+
+FILE_HEADER = b"fast_llm_prepared_dataset"
+
+
+class MemmapDataset[SampleType: Sample](IndexedDataset[SampleType]):
+    """
+    A memory map dataset, which handles lazy loading of a pre-processed dataset.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        path: pathlib.Path | str,
+    ):
+        self._init(name, path)
+
+    def _init(self, name: str, path: pathlib.Path | str) -> None:
+        super().__init__()
+        self._name = name
+        self._path = pathlib.Path(path)
+
+        with self._path.open("rb") as stream:
+            # Verify the file type.
+            assert stream.read(len(FILE_HEADER)) == FILE_HEADER
+            # Go to the reader config.
+            stream.seek(int.from_bytes(stream.read(8), signed=False))
+            # Read the reader config.
+            reader_config = MemmapIndexDatasetReaderConfig.from_dict(
+                json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8"))
+            )
+
+        self._memmap = np.memmap(self._path, mode="r")
+        self._reader = reader_config.get_reader(memoryview(self._memmap))
+
+    def __getstate__(self) -> tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]:
+        # We pass the reader config to force its import in data loader workers.
+        return self._name, self._path, self._reader.config
+
+    def __setstate__(self, state: tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]):
+        name, path, _ = state
+        self._init(name, path)
+
+    def __del__(self):
+        if hasattr(self, "_memmap"):
+            self._memmap._mmap.close()  # noqa
+            del self._memmap
+
+    def get_document(
+        self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None
+    ) -> SampleType:
+        if end is None:
+            end = self._reader.get_document_size(index)
+        return self._reader.get_document(index, begin, end)
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def __len__(self) -> int:
+        return len(self._reader)
+
+    @property
+    def num_tokens(self) -> int:
+        return self._reader.num_tokens
+
+    def get_document_sizes(self) -> torch.Tensor:
+        return self._reader.get_document_sizes()
+
+    def get_document_size(self, index: int) -> int:
+        return self._reader.get_document_size(index)
+
+    @classmethod
+    def write_dataset(
+        cls, path: pathlib.Path, documents: typing.Iterable[Sample], writer_class: type[MemmapWriter]
+    ) -> MemmapIndexDatasetReaderConfig:
+        # TODO: Match `writer_class` with `SampleType`?
+        path.parent.mkdir(parents=True, exist_ok=True)
+        with path.open("wb") as stream:
+            # Write the file type header.
+            stream.write(FILE_HEADER)
+            # Leave space for a pointer to the reader config.
+            # We write the config at the end since we don't know it yet.
+            start = stream.tell()
+            stream.seek(start + 8)
+            # Write the data.
+            reader_config = writer_class.write_dataset(stream, documents)
+            # Write the reader config.
+ config_offset = stream.tell() + reader_config_bytes = json.dumps(reader_config.to_dict()).encode("utf-8") + stream.write(len(reader_config_bytes).to_bytes(4, signed=False)) + stream.write(reader_config_bytes) + # Write a pointer to the reader config. + stream.seek(start) + stream.write(config_offset.to_bytes(8, signed=False)) + return reader_config diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 46a518cd0..d51a68746 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -414,7 +414,6 @@ def __getitem__(self, index: int) -> SampleType: document_sampling_index += 1 token_count += document_size - # TODO: ====== Better way to get the class method? ====== return documents[0].from_documents(documents) @property diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index 160fccafc..a774fc3de 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -7,7 +7,6 @@ @config_class(registry=True, dynamic_type={RunnableConfig: "prepare"}) class DatasetPreparatorConfig(RunnableConfig): - preparator_name: typing.ClassVar[str] @classmethod def get_dataset_preparator_class(cls) -> type["DatasetPreparator"]: diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index d2aaee5e2..9bf292033 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -1,50 +1,68 @@ +import functools import os import pathlib import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import TokenizerConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -MEMMAP_DTYPES = { - 1: DataType.uint8, - 2: DataType.int8, - 3: DataType.int16, - 4: DataType.int32, - 5: DataType.int64, - 6: DataType.float32, - 7: DataType.float64, - 8: DataType.uint16, -} -MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()} -MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" - - -@config_class(registry=True) -class SourceSchemaConfig(Config): - pass - - -@config_class(dynamic_type={SourceSchemaConfig: "text_column"}) -class TextColumnConfig(SourceSchemaConfig): - input_column: str = Field( + + +@config_class() +class LanguageModelSourceConfig(Config): + """ + A schema holding the name of each relevant column in the dataset. + Setting optional entries will enable the associated feature. 
+ """ + + text: str = Field( default="text", desc="Field of the dataset to use.", hint=FieldHint.optional, ) - loss_masking_spans_column: None | str = Field( + loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) + chosen_span: None | str = Field( + default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional + ) + rejected_span: None | str = Field( + default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional + ) + + @functools.cached_property + def columns(self) -> list[str]: + columns = [self.text] + if self.has_loss_masking_span: + columns.append(self.loss_masking_spans) + if self.has_preference_spans: + columns.extend([self.chosen_span, self.rejected_span]) + return columns + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + return self.loss_masking_spans is not None + + @functools.cached_property + def has_preference_spans(self) -> bool: + Assert.eq(self.chosen_span is None, self.rejected_span is None) + return self.chosen_span is not None + + def _validate(self): + super()._validate() + if self.has_preference_spans and self.has_loss_masking_span: + raise ValueError(f"Can not enable both loss masking and preference spans.") @config_class() class GPTHuggingfaceDatasetConfig(Config): - path: str = Field( + path: str | pathlib.Path = Field( default=None, desc="Name or path of the dataset.", hint=FieldHint.core, @@ -69,16 +87,10 @@ class GPTHuggingfaceDatasetConfig(Config): desc="Split of the dataset to use.", hint=FieldHint.optional, ) - source_schema: SourceSchemaConfig = Field( + source_schema: LanguageModelSourceConfig = Field( desc="Configuration for the data source.", hint=FieldHint.optional, ) - chosen_text: None | str = Field( - default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional - ) - rejected_text: None | str = Field( - default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional - ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." @@ -95,6 +107,11 @@ class GPTHuggingfaceDatasetConfig(Config): desc="Disable disk space check. 
Useful for environments where disk space is not accurately reported.", hint=FieldHint.optional, ) + load_from_disk: bool = Field( + default=False, + desc="Use the `load_from_disk` method for datasets saved with `save_to_disk`.", + hint=FieldHint.feature, + ) @config_class() @@ -132,8 +149,6 @@ def _validate(self) -> None: @config_class(dynamic_type={RunnableConfig: "prepare_gpt_memmap", DatasetPreparatorConfig: "gpt_memmap"}) class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): - preparator_name: typing.ClassVar[str] = "gpt_memmap" - output_path: pathlib.Path = Field( default=None, desc="Output directory for the processed dataset.", @@ -143,27 +158,14 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for distributed processing.", hint=FieldHint.feature, ) - tokens_per_shard: int = Field( - default=10**9, - desc="Approximate number of tokens per shard.", + documents_per_shard: int = Field( + default=10**6, + desc="Target number of documents per shard.", hint=FieldHint.feature, - valid=check_field(Assert.geq, 10**5), ) - loading_workers: int = Field( + num_workers: int = Field( default=1, - desc="Number of workers in load_dataset() call.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 1), - ) - tokenize_workers: int = Field( - default=1, - desc="Number of workers for tokenization.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 1), - ) - saving_workers: int = Field( - default=1, - desc="Number of processes for saving the data.", + desc="Number of parallel workers.", hint=FieldHint.optional, valid=check_field(Assert.geq, 1), ) @@ -183,10 +185,8 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): ) def _validate(self) -> None: - assert self.tokenizer.path is not None - if self.dataset.data_type is not None: - Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() + assert self.tokenizer.path is not None @classmethod def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 274bbf1b0..18d4d46e2 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,5 +1,8 @@ +import collections +import enum import json import logging +import math import multiprocessing import pathlib import shutil @@ -18,182 +21,53 @@ BlendedDatasetConfig, DatasetSliceConfig, IndexedDatasetConfig, + MemmapDatasetConfig, SampledDatasetConfig, ) -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig -from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.preprocessing.tokenizer import Tokenizer +from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig +from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import 
DataType, get_unsigned_integer_type +from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum logger = logging.getLogger(__name__) +class SpanType(enum.StrEnum): + loss_masking = "loss_masking" + chosen = "chosen" + rejected = "rejected" + + class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): _tokenizer: Tokenizer _data_type: DataType - _text_column: str - _loss_masking_spans_column: str | None _sample_type: typing.ClassVar[type[LanguageModelSample]] = LanguageModelSample + _config: GPTMemmapDatasetPreparatorConfig - def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column] - ] - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "num_tokens": num_tokens, - } - - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( - list, - zip( - *[ - ( - np.array(input_ids, dtype=self._data_type.numpy), - np.array(token_spans, dtype=np.int32).reshape(-1, 2), - ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column]) - ] - ] - ), - ) - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "token_spans": token_spans, - "num_tokens": num_tokens, - } - - def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - packed_texts = [] - chosen_spans = [] - rejected_spans = [] - - for conv_history, chosen_text, rejected_text in zip( - batch[self._config.dataset.field], - batch[self._config.dataset.chosen_text], - batch[self._config.dataset.rejected_text], - ): - # compute chosen span - full_chosen_text = conv_history + chosen_text + self._tokenizer.tokenizer.eos_token - chosen_span = [len(conv_history), len(full_chosen_text) - 1] - offset = len(full_chosen_text) - chosen_spans.append(chosen_span) - - # compute rejected span - full_rejected_text = self._tokenizer.tokenizer.bos_token + conv_history + rejected_text - rejected_span = [ - offset + len(self._tokenizer.tokenizer.bos_token + conv_history), - offset + len(full_rejected_text) - 1, - ] - rejected_spans.append(rejected_span) - - # pack texts - packed_text = full_chosen_text + full_rejected_text - - assert ( - packed_text[chosen_span[0] : chosen_span[1] + 1] == chosen_text + self._tokenizer.tokenizer.eos_token - ), f"{packed_text[chosen_span[0]: chosen_span[1] + 1]} does not match {chosen_text}" - - assert ( - packed_text[rejected_span[0] : rejected_span[1] + 1] == rejected_text - ), f"{packed_text[rejected_span[0]: rejected_span[1] + 1]} does not match {rejected_text}" - packed_texts.append(packed_text) - - # tokenize with spans - input_ids, chosen_token_spans, rejected_token_spans = map( - list, - zip( - *[ - ( - np.array(input_ids, dtype=self._data_type.numpy), - np.array(token_spans[0], dtype=np.int32), - np.array( - [token_spans[1][0], token_spans[1][1] + 1], dtype=np.int32 - ), # adding 1 to end for eos token - ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, [chosen_span, rejected_span]) - for text, chosen_span, rejected_span in zip(packed_texts, chosen_spans, rejected_spans) - ] - ] - ), - ) - - num_tokens = 
[len(x) for x in input_ids] - return { - "input_ids": input_ids, - "chosen_token_spans": chosen_token_spans, - "rejected_token_spans": rejected_token_spans, - "num_tokens": num_tokens, - } - - def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: - shard_idx, shard_dataset = args - prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" - shard_output_path = self._config.output_path / prefix - - def _document_generator(): - # TODO: Yield `LanguageModelSample` - if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield ( - torch.tensor(item["input_ids"], dtype=self._data_type.torch), - torch.tensor(item["token_spans"], dtype=torch.int32).reshape(-1, 2), - None, - None, - ) - elif ( - "chosen_token_spans" in shard_dataset.column_names - and "rejected_token_spans" in shard_dataset.column_names - and self._config.dataset.chosen_text is not None - and self._config.dataset.rejected_text is not None - ): - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield ( - torch.tensor(item["input_ids"], dtype=self._data_type.torch), - None, - torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), - torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield ( - torch.tensor(item["input_ids"], dtype=self._data_type.torch), - None, - None, - None, - ) - - GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) - - return GPTMemmapDatasetConfig.from_dict( - { - "type": "memmap", - "path": prefix, - "num_documents": len(shard_dataset), # Use the length of the shard dataset directly - "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), - } - ) + def __init__(self, config: ConfigType): + super().__init__(config) + self._source_schema: LanguageModelSourceConfig = self._config.dataset.source_schema def _load_dataset(self) -> datasets.Dataset: - dataset = datasets.load_dataset( - path=self._config.dataset.path, - name=self._config.dataset.config_name, - data_dir=self._config.dataset.data_directory, - data_files=self._config.dataset.data_files, - split=self._config.dataset.split, - num_proc=self._config.loading_workers, - trust_remote_code=self._config.dataset.trust_remote_code, - ) + if self._config.dataset.load_from_disk: + dataset = datasets.load_from_disk(self._config.dataset.path)[self._config.dataset.split] + else: + dataset = datasets.load_dataset( + path=self._config.dataset.path, + name=self._config.dataset.config_name, + data_dir=self._config.dataset.data_directory, + data_files=self._config.dataset.data_files, + split=self._config.dataset.split, + num_proc=self._config.num_workers, + trust_remote_code=self._config.dataset.trust_remote_code, + ) assert isinstance(dataset, datasets.Dataset) return dataset @@ -261,6 +135,7 @@ def run(self) -> None: # Initialize distributed processing if self._config.distributed.world_size > 1: + log_main_rank(f"> Initializing distributed process groups ...") torch.distributed.init_process_group( backend=self._config.distributed.backend, rank=self._config.distributed.rank, @@ -270,111 +145,153 @@ def run(self) -> None: # Prepare output directory self._config.output_path.mkdir(parents=True, exist_ok=True) - if pathlib.Path(self._config.dataset.path).is_dir(): - # Dataset is 
already downloaded, load from disk + log_main_rank(f"> Loading dataset `{self._config.dataset.path}` ...") + if self._config.distributed.world_size == 1: dataset = self._load_dataset() + elif self._config.distributed.rank == 0: + # Load first on rank 0 to prevent parallel downloads. + dataset = self._load_dataset() + torch.distributed.barrier() else: - # Dataset is not downloaded, download on rank 0 - if self._config.distributed.rank == 0: - dataset = self._load_dataset() - - # Synchronize processes to wait for the download to finish on rank 0 - if self._config.distributed.world_size > 1: - torch.distributed.barrier() - + torch.distributed.barrier() # Load the downloaded dataset on remaining ranks - if self._config.distributed.rank != 0: - dataset = self._load_dataset() - - # Synchronize processes to wait for the dataset to load on remaining ranks - if self._config.distributed.world_size > 1: - torch.distributed.barrier() + dataset = self._load_dataset() - assert isinstance(dataset, datasets.Dataset) dataset = dataset.shard( num_shards=self._config.distributed.world_size, index=self._config.distributed.rank, ) - # Set data column and loss masking spans column based on source schema - if isinstance(self._config.dataset.source_schema, TextColumnConfig): - self._text_column = self._config.dataset.source_schema.input_column - self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column - else: - raise ValueError( - f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'." - ) + for column_name in self._source_schema.columns: + if column_name not in dataset.column_names: + raise ValueError(f"Dataset does not have field '{column_name}'.") - if self._text_column not in dataset.column_names: - raise ValueError(f"Dataset does not have field '{self._text_column}'.") - - if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( - self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None - ): - raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") - if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): - raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") - - # route tokenize function - if self._loss_masking_spans_column is not None: - if self._loss_masking_spans_column not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") - tokenize_fn = self._tokenize_batch_with_spans - elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: - if self._config.dataset.chosen_text not in dataset.column_names: - raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") - if self._config.dataset.rejected_text not in dataset.column_names: - raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.") - tokenize_fn = self._tokenize_preference_batch_with_spans - else: - tokenize_fn = self._tokenize_batch - - # Tokenize the dataset in parallel - tokenized_dataset = dataset.map( - tokenize_fn, - batched=True, - num_proc=self._config.tokenize_workers, - desc="Tokenizing batches", + # Split dataset into shards based on number of tokens + num_shards = math.ceil(len(dataset) / self._config.documents_per_shard) + shards = [(i, dataset.shard(num_shards=num_shards, 
index=i)) for i in range(num_shards)] + + log_main_rank(f"> Preparing samples on {self._config.num_workers} workers ...") + + # Use multiprocessing to save each shard in parallel on all ranks + with multiprocessing.Pool(processes=self._config.num_workers) as pool: + dataset_and_reader_configs = pool.map(self._prepare_shard, shards) + + log_main_rank(f"> Generating dataset config ...") + self.generate_config_yaml_for_sharded_dst(dataset_and_reader_configs) + + def _prepare_shard( + self, args: tuple[int, datasets.Dataset] + ) -> tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]: + shard_index, shard_dataset = args + file_name = f"shard_{self._config.distributed.rank}_{shard_index}.fast_llm_dataset" + + reader_config = MemmapDataset.write_dataset( + self._config.output_path / file_name, + ( + self._prepare_sample(sample) + for sample in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_index}", unit="docs") + ), + LanguageModelWriter, ) + return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config + + def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: + # TODO: ======= Extract so we can use elsewhere? (ex. inference) ====== + text = sample[self._source_schema.text] + all_spans = [] + if self._source_schema.has_loss_masking_span: + # TODO: ====== What is the exact input format? ====== + # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. + loss_masking_spans = _sort_spans( + (SpanType.loss_masking, (begin, last + 1)) + for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) + .reshape(-1, 2) + .tolist() + ) + all_spans.extend(loss_masking_spans) - # Calculate total number of tokens - total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + if self._source_schema.has_preference_spans: + # TODO: ===== Was `self._config.dataset.field` (bug?) ====== + full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token + full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] + # compute chosen span + chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] - # Split dataset into shards based on number of tokens - num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) - shards = [ - (i, tokenized_dataset.shard(num_shards=num_shards, index=i)) - for i in tqdm.tqdm(range(num_shards), desc="Creating shards") - ] + # compute rejected span + rejected_span = [ + ( + SpanType.rejected, + ( + len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), + len(full_chosen_text) + len(full_rejected_text), + ), + ) + ] + # pack texts + text = full_chosen_text + full_rejected_text + all_spans.extend(chosen_spans + rejected_span) + + # Sort the spans by location (begin), keeping track of their type. + # Note: overlapping spans are not supported (explicit assertion in the tokenizer). + span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) + # Tokenize the text, and determine the span locations in the tokenized text. 
+ tokens, token_spans = self._tokenizer.tokenize_with_spans( + text, True, True, text_spans=spans, data_type=self._data_type + ) - # Use multiprocessing to save each shard in parallel on all ranks - with multiprocessing.Pool(processes=self._config.saving_workers) as pool: - dataset_configs = pool.map(self._save_shard, shards) + # Gather token spans by type. + token_spans_by_type = collections.defaultdict(list) + for span_type, token_span in zip(span_types, token_spans, strict=True): + token_spans_by_type[span_type].append(token_span) - self.generate_config_yaml_for_sharded_dst(dataset_configs) + sample_size = len(tokens) - def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDatasetConfig]) -> None: + return LanguageModelSample( + TokenSample(tokens, [sample_size]), + ( + RangeSample(token_spans_by_type[SpanType.loss_masking], sample_size) + if self._source_schema.has_loss_masking_span + else None + ), + ( + RangeSample(token_spans_by_type[SpanType.chosen], sample_size) + if self._source_schema.has_preference_spans + else None + ), + ( + # `tokenize_with_spans` excludes the final eod token from the rejected span, but we want to include it. + RangeSample([(begin, end + 1) for begin, end in token_spans_by_type[SpanType.rejected]], sample_size) + if self._source_schema.has_preference_spans + else None + ), + ) + + def generate_config_yaml_for_sharded_dst( + self, dataset_and_reader_configs: list[tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]] + ) -> None: # Gather dataset_dicts from all ranks to rank 0 if self._config.distributed.world_size > 1: if self._config.distributed.rank == 0: - all_dataset_configs = [None] * self._config.distributed.world_size - torch.distributed.gather_object(dataset_configs, all_dataset_configs, dst=0) - dataset_configs = [item for sublist in all_dataset_configs for item in sublist] + all_dataset_and_reader_configs = [None] * self._config.distributed.world_size + torch.distributed.gather_object(dataset_and_reader_configs, all_dataset_and_reader_configs, dst=0) + dataset_and_reader_configs = [item for sublist in all_dataset_and_reader_configs for item in sublist] else: - torch.distributed.gather_object(dataset_configs, [], dst=0) + torch.distributed.gather_object(dataset_and_reader_configs, [], dst=0) if self._config.distributed.rank == 0: # Create the config file(s) on rank 0 + dataset_configs, reader_configs = zip(*dataset_and_reader_configs) if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, reader_configs, self._config.splits, self._config.output_path ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" ) else: self._save_dataset_config( - self._blend_dataset_configs(dataset_configs), self._config.output_path / f"fast_llm_config.yaml" + self._blend_dataset_configs(dataset_configs, reader_configs), + self._config.output_path / f"fast_llm_config.yaml", ) # Save metadata on rank 0 @@ -397,7 +314,9 @@ def _save_dataset_config( @classmethod def _blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]] + cls, + dataset_configs: list[MemmapDatasetConfig[_sample_type]], + reader_configs: list[MemmapIndexDatasetReaderConfig], ) -> IndexedDatasetConfig[_sample_type]: if len(dataset_configs) == 1: return dataset_configs[0] @@ -405,19 +324,20 @@ def _blend_dataset_configs( { "type": "blended", "datasets": 
dataset_configs, - "weights": [dataset_config.num_tokens for dataset_config in dataset_configs], + "weights": [reader_config.num_tokens for reader_config in reader_configs], } ) @classmethod def _split_and_blend_dataset_configs( cls, - dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]], + dataset_configs: list[MemmapDatasetConfig[_sample_type]], + reader_configs: list[MemmapIndexDatasetReaderConfig], splits: dict[str, int | float], output_path: pathlib.Path, ) -> dict[str, SampledDatasetConfig[_sample_type]]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() - dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] + dataset_sizes = [reader_config.num_tokens for reader_config in reader_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) dataset_cumsums = padded_cumsum(dataset_probabilities).tolist() dataset_splits = {} @@ -425,7 +345,9 @@ def _split_and_blend_dataset_configs( for split_index, split_name in enumerate(splits): datasets_in_split = [] dataset_tokens_in_split = [] - for dataset_index, dataset_config in enumerate(dataset_configs): + for dataset_index, (dataset_config, reader_config) in enumerate( + zip(dataset_configs, reader_configs, strict=True) + ): split_begin_in_dataset = max( (split_cumsum[split_index] - dataset_cumsums[dataset_index]) / dataset_probabilities[dataset_index], @@ -445,17 +367,17 @@ def _split_and_blend_dataset_configs( # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() sizes_cumsum = dataset.get_document_sizes().numpy().cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + Assert.eq(sizes_cumsum[-1], reader_config.num_tokens) + begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * reader_config.num_tokens) + end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * reader_config.num_tokens) if end_index > begin_index: datasets_in_split.append( DatasetSliceConfig[cls._sample_type].from_dict( { "type": "slice", "dataset": dataset_configs[dataset_index], - "begin": begin_index / dataset_config.num_documents, - "end": end_index / dataset_config.num_documents, + "begin": begin_index / len(reader_config), + "end": end_index / len(reader_config), } ) ) @@ -483,6 +405,10 @@ def _split_and_blend_dataset_configs( return dataset_splits +def _sort_spans(spans: typing.Iterable[tuple[SpanType, tuple[int, int]]]) -> list[tuple[SpanType, tuple[int, int]]]: + return sorted(spans, key=lambda span: span[1][0]) + + def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: left = cumsum.searchsorted(value, side="right") if left == len(cumsum): diff --git a/fast_llm/data/preprocessing/__init__.py b/fast_llm/data/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py new file mode 100644 index 000000000..70291bcaa --- /dev/null +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -0,0 +1,196 @@ +import pathlib +import typing + +from fast_llm.config import Config, Configurable, Field, FieldHint, config_class +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.run 
import log_main_rank +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + import numpy as np + import torch + + +@config_class() +class TokenizerConfig(Config): + """ + Configuration for the tokenizer. + The tokenizer is needed for FIM and dataset preparation. + """ + + path: pathlib.Path = Field( + default=None, + desc="Path to the tokenizer file.", + hint=FieldHint.core, + ) + bos_token: str | None = Field( + default=None, + desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", + hint=FieldHint.core, + ) + max_vocab_size: int | None = Field( + default=None, + desc="Constrain output tokens to a specific range. Used for testing.", + hint=FieldHint.testing, + ) + + def get_tokenizer(self) -> "Tokenizer": + from fast_llm.data.preprocessing.tokenizer import Tokenizer + + return Tokenizer(self) + + +class Tokenizer[ConfigType: TokenizerConfig](Configurable[ConfigType]): + """ + A wrapper around Huggingface (transformers) tokenizer. + """ + + def __init__(self, config: ConfigType): + super().__init__(config) + from transformers import AutoTokenizer + + log_main_rank(f"> loading tokenizer from {config.path} ...") + self.tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=self._config.path, + errors="replace", + max_len=None, + trust_remote_code=True, + use_fast=True, + ) + if self._config.bos_token is not None: + self.tokenizer.bos_token = self._config.bos_token + if self.tokenizer.eos_token_id is None: + raise ValueError("Tokenizer does not have an EOS token.") + if self.tokenizer.bos_token_id is None: + raise ValueError("Tokenizer does not have an BOS token.") + self.eod_id = self.tokenizer.eos_token_id + self.bod_id = self.tokenizer.bos_token_id + + @property + def vocab_size(self) -> int: + return len(self.tokenizer) + + @property + def vocab(self) -> dict[str, int]: + return self.tokenizer.vocab + + @property + def inv_vocab(self) -> dict[int, str]: + return self._inv_vocab + + def tokenize( + self, text: str, begin: bool = True, end: bool = True, data_type: DataType = DataType.int64 + ) -> "torch.Tensor": + import torch + + tokens = torch.tensor( + ([self.bod_id] if begin else []) + + self.tokenizer.encode(text, add_special_tokens=False) + + ([self.eod_id] if end else []), + dtype=data_type.torch, + ) + if self._config.max_vocab_size is not None: + tokens %= self._config.max_vocab_size + return tokens + + def tokenize_with_spans( + self, + text: str, + begin: bool = True, + end: bool = True, + *, + text_spans: list[tuple[int, int]], + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[tuple[int, int]]]: + """ + Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
+ """ + if not text_spans: + return self.tokenize(text, begin, end, data_type=data_type), [] + input_ids, token_splits = self.tokenize_with_splits( + text, begin, end, text_splits=[split for splits in text_spans for split in splits], data_type=data_type + ) + return input_ids, [(begin, end) for begin, end in zip(token_splits[::2], token_splits[1::2], strict=True)] + + def tokenize_with_splits( + self, + text: str, + begin: bool = True, + end: bool = True, + *, + text_splits: list[int], + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[int]]: + if not text_splits: + return self.tokenize(text, begin, end, data_type=data_type), [] + import torch + + Assert.eq(sorted(text_splits), text_splits) + input_ids = [] + text_splits = [0, *text_splits, len(text)] + token_splits = [] + total_tokens = 0 + + for i, (split_begin, split_end) in enumerate(zip(text_splits[:-1], text_splits[1:])): + input_ids.append( + split_tokens := self.tokenize( + text[split_begin:split_end], + begin and i == 0, + end and i == len(text_splits) - 2, + data_type=data_type, + ) + ) + total_tokens += len(split_tokens) + token_splits.append(total_tokens) + + return torch.cat(input_ids), token_splits[:-1] + + def detokenize( + self, tokens: "int | list[int] | np.ndarray | torch.Tensor", begin: bool = False, end: bool = False + ) -> str: + tokens = self._remove_delimiters(tokens, begin, end) + return self.tokenizer.decode(tokens) + + def detokenize_with_spans( + self, tokens: "torch.Tensor", begin: bool = False, end: bool = False, *, token_spans: list[tuple[int, int]] + ) -> tuple[str, list[tuple[int, int]]]: + if not token_spans: + return self.detokenize(tokens, begin, end), [] + text, text_splits = self.detokenize_with_splits( + tokens, begin, end, token_splits=[split for splits in token_spans for split in splits] + ) + return text, [(begin, end) for begin, end in zip(text_splits[::2], text_splits[1::2], strict=True)] + + def detokenize_with_splits( + self, tokens: "torch.Tensor", begin: bool = False, end: bool = False, *, token_splits: list[int] + ) -> tuple[str, list[int]]: + if not token_splits: + return self.detokenize(tokens, begin, end), [] + Assert.eq(sorted(token_splits), token_splits) + tokens = self._remove_delimiters(tokens, begin, end) + texts = [] + token_splits = [0, *(token_split - begin for token_split in token_splits), len(tokens)] + text_splits = [] + total_characters = 0 + + for i, (split_begin, split_end) in enumerate(zip(token_splits[:-1], token_splits[1:])): + texts.append(split_text := self.detokenize(tokens[split_begin:split_end])) + total_characters += len(split_text) + text_splits.append(total_characters) + + return "".join(texts), text_splits[:-1] + + def _remove_delimiters( + self, token_ids: "int | list[int] | np.ndarray | torch.Tensor", begin: bool = False, end: bool = False + ): + if begin: + Assert.eq(token_ids[0], self.bod_id) + token_ids = token_ids[1:] + if end: + Assert.eq(token_ids[-1], self.eod_id) + token_ids = token_ids[:-1] + return token_ids + + @property + def eod(self): + return self.eod_id diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 031002101..aaa321efd 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -1,6 +1,11 @@ import abc +import io +import pathlib import typing +from fast_llm.config import Config, Configurable, Field, config_class +from fast_llm.utils import Assert + if typing.TYPE_CHECKING: import torch @@ -40,3 +45,193 @@ def crop(self, begin: int, end: int) -> 
typing.Self: def to_device_(self, device: "torch.device | str"): pass + + +@config_class(registry=True) +class MemmapReaderBaseConfig(Config): + """ + Configuration for a memmap reader or reader-like object. + Note: `MemmapDataset` requires a `MemmapIndexedDatasetReader`. + Other readers need to be nested within a `MemmapIndexedDatasetReader` + Note: Reader configs are not typical configs, and do not need to be located in a separate `config.py` file. + """ + + _abstract = True + + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is MemmapReaderBaseConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass, necessary for loading configs where some components could be absent. + return NullReaderConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + def get_reader(self, buffer: memoryview) -> "MemmapReader|None": + raise NotImplementedError() + + @property + def expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, including header and footer. Used for self-validation. + """ + raise NotImplementedError() + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "none"}) +class NullReaderConfig(MemmapReaderBaseConfig): + """ + Configuration for a dynamically disabled reader. + """ + + _abstract = False + + def get_reader(self, buffer: memoryview) -> None: + return None + + @property + def expected_buffer_size(self) -> int: + return 0 + + +@config_class(registry=True) +class MemmapReaderConfig(MemmapReaderBaseConfig): + """ + Configuration for a standard memmap reader. + """ + + # Data location in the file. + begin: int = Field() + end: int = Field() + # Constant strings for alignment safety. + header: typing.ClassVar[bytes] + footer: typing.ClassVar[bytes] + + @property + def reader_class(self) -> "type[MemmapReader]": + raise NotImplementedError() + + def get_reader(self, buffer: memoryview) -> "MemmapReader": + return self.reader_class(self, buffer) + + @property + def expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, including header and footer. Used for self-validation. + """ + return self._expected_buffer_size + len(self.header) + len(self.footer) + + @property + def _expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, excluding header and footer. Used for self-validation. + """ + raise NotImplementedError() + + @property + def writer_class(self) -> "type[MemmapWriter]": + raise NotImplementedError() + + def get_writer(self, stream: io.BufferedWriter) -> "MemmapWriter": + return self.writer_class(stream) + + def _validate(self): + super()._validate() + Assert.eq(self.end - self.begin, self.expected_buffer_size) + + +@config_class() +class MemmapIndexDatasetReaderConfig(MemmapReaderConfig): + """ + Configuration for a standard memmap reader matching the indexed dataset interface, i.e., + consisting of a list of documents of known lengths. 
+ """ + + def __len__(self) -> int: + raise NotImplementedError() + + @property + def num_tokens(self) -> int: + raise NotImplementedError() + + @property + def reader_class(self) -> "type[MemmapIndexedDatasetReader]": + raise NotImplementedError() + + def get_reader( + self, + buffer: memoryview, + ) -> "MemmapIndexedDatasetReader": + return self.reader_class(self, buffer) + + +class MemmapReader[ConfigType: MemmapReaderConfig](Configurable[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config) + buffer_begin = self._config.begin + len(self._config.header) + buffer_end = self._config.end - len(self._config.footer) + Assert.eq(buffer[self._config.begin : buffer_begin].tobytes(), self._config.header) + Assert.eq(buffer[buffer_end : self._config.end].tobytes(), self._config.footer) + self._buffer = buffer[buffer_begin:buffer_end] + + @abc.abstractmethod + def get_document(self, index: int, begin: int, end: int) -> Sample: + pass + + +class MemmapIndexedDatasetReader[ConfigType: MemmapIndexDatasetReaderConfig](MemmapReader[ConfigType]): + def __len__(self) -> int: + return len(self._config) + + @property + def num_tokens(self) -> int: + return self._config.num_tokens + + @abc.abstractmethod + def get_document_sizes(self) -> "torch.Tensor": + pass + + @abc.abstractmethod + def get_document_size(self, index: int) -> int: + pass + + +class MemmapWriter(abc.ABC): + def __init__(self, stream: io.BufferedWriter | pathlib.Path): + self._owns_stream = isinstance(stream, pathlib.Path) + if self._owns_stream: + stream = stream.open("wb") + self._stream = stream + + def __enter__(self): + self._begin = self._stream.tell() + self._stream.write(self._get_config_class().header) + return self + + def write(self, document: Sample): + assert hasattr(self, "_begin") and not hasattr(self, "_end") + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self._stream.write(self._get_config_class().footer) + self._end = self._stream.tell() + if self._owns_stream: + self._stream.close() + + @classmethod + @abc.abstractmethod + def _get_config_class(cls) -> type[MemmapReaderConfig]: + pass + + def get_config(self, offset: int = 0) -> MemmapReaderConfig: + assert hasattr(self, "_end") + return self._get_config(self._begin + offset, self._end + offset) + + @abc.abstractmethod + def _get_config(self, begin: int, end: int): + pass + + @classmethod + def write_dataset(cls, stream: io.BufferedWriter, documents: typing.Iterable[Sample]) -> MemmapReaderConfig: + with cls(stream) as writer: + for document in documents: + writer.write(document) + return writer.get_config() diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index f30188553..6f485bf84 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -1,8 +1,23 @@ +import io +import pathlib +import tempfile import typing -from fast_llm.data.sample.abstract import Batch, Sample -from fast_llm.data.sample.range import RangeBatch, RangeSample -from fast_llm.data.sample.token import TokenBatch, TokenSample +import torch + +from fast_llm.config import Field, config_class +from fast_llm.data.sample.abstract import ( + Batch, + MemmapIndexDatasetReaderConfig, + MemmapIndexedDatasetReader, + MemmapReaderBaseConfig, + MemmapWriter, + NullReaderConfig, + Sample, +) +from fast_llm.data.sample.range import RangeBatch, RangeSample, RangeWriter +from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, 
TokenWriter +from fast_llm.utils import Assert class LanguageModelSample(Sample): @@ -74,9 +89,9 @@ def to_samples(self) -> list[LanguageModelSample]: LanguageModelSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( self.tokens.to_samples(), - self.loss_masking_spans.to_samples(), - self.chosen_spans.to_samples(), - self.rejected_spans.to_samples(), + None if self.loss_masking_spans is None else self.loss_masking_spans.to_samples(), + None if self.chosen_spans is None else self.chosen_spans.to_samples(), + None if self.rejected_spans is None else self.rejected_spans.to_samples(), strict=True, ) ] @@ -105,3 +120,190 @@ def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.I def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) -> T | None: return None if sample_or_batch is None else sample_or_batch.crop(begin, end) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "language_model"}) +class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): + _abstract = False + header: typing.ClassVar[bytes] = b"lm begin" + footer: typing.ClassVar[bytes] = b"lm end" + tokens: TokenReaderConfig = Field() + # Using dynamic type for optional readers for enabling/disabling + loss_masking_spans: MemmapReaderBaseConfig = Field() + chosen_spans: MemmapReaderBaseConfig = Field() + rejected_spans: MemmapReaderBaseConfig = Field() + + def __len__(self) -> int: + return len(self.tokens) + + @property + def num_tokens(self) -> int: + return self.tokens.num_tokens + + @property + def reader_class(self) -> "type[LanguageModelReader]": + return LanguageModelReader + + @property + def writer_class(self) -> "type[LanguageModelWriter]": + return LanguageModelWriter + + @property + def _expected_buffer_size(self) -> int: + return ( + self.tokens.expected_buffer_size + + self.loss_masking_spans.expected_buffer_size + + self.chosen_spans.expected_buffer_size + + self.rejected_spans.expected_buffer_size + ) + + +class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global. 
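+ # Optional readers may be disabled (`NullReaderConfig`), in which case `get_reader` returns `None` and the corresponding component is skipped in `get_document`.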
+ self._tokens = self._config.tokens.get_reader(buffer)
+ self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer)
+ self._chosen_spans = self._config.chosen_spans.get_reader(buffer)
+ self._rejected_spans = self._config.rejected_spans.get_reader(buffer)
+
+ @property
+ def num_tokens(self) -> int:
+ return self._config.tokens.num_tokens
+
+ def get_document(self, index: int, begin: int, end: int) -> Sample:
+ return LanguageModelSample(
+ self._tokens.get_document(index, begin, end),
+ None if self._loss_masking_spans is None else self._loss_masking_spans.get_document(index, begin, end),
+ None if self._chosen_spans is None else self._chosen_spans.get_document(index, begin, end),
+ None if self._rejected_spans is None else self._rejected_spans.get_document(index, begin, end),
+ )
+
+ def get_document_sizes(self) -> torch.Tensor:
+ return self._tokens.get_document_sizes()
+
+ def get_document_size(self, index: int) -> int:
+ return self._tokens.get_document_size(index)
+
+
+class LanguageModelWriter(MemmapWriter):
+ _has_loss_masking_spans: bool | None = None
+ _has_preference_spans: bool | None = None
+
+ def __enter__(self):
+ super().__enter__()
+ self._size_cumsum = [0]
+ self._data_type = None
+
+ self._directory = tempfile.TemporaryDirectory()
+ self._path = pathlib.Path(self._directory.name)
+ # We write intermediate results in separate files so we don't need to iterate over the dataset multiple times.
+ self._token_writer = TokenWriter(self._path.joinpath("tokens")).__enter__()
+ self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__()
+ self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__()
+ self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__()
+ return self
+
+ def write(self, document: LanguageModelSample):
+ super().write(document)
+ # Write tokens.
+ self._token_writer.write(document.tokens)
+
+ # Ensure either all samples have loss masking spans or none of them do.
+ if self._has_loss_masking_spans is None:
+ self._has_loss_masking_spans = document.loss_masking_spans is not None
+ else:
+ Assert.eq(self._has_loss_masking_spans, document.loss_masking_spans is not None)
+
+ # Write loss masking spans.
+ if self._has_loss_masking_spans:
+ self._loss_masking_span_writer.write(document.loss_masking_spans)
+
+ # All samples must either have both chosen and rejected spans, or neither.
+ if self._has_preference_spans is None:
+ self._has_preference_spans = document.chosen_spans is not None
+ else:
+ Assert.eq(self._has_preference_spans, document.chosen_spans is not None)
+ Assert.eq(self._has_preference_spans, document.rejected_spans is not None)
+
+ # Write preference spans.
+ if self._has_preference_spans:
+ self._chosen_spans_writer.write(document.chosen_spans)
+ self._rejected_spans_writer.write(document.rejected_spans)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._token_writer.__exit__(exc_type, exc_val, exc_tb)
+ self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb)
+ self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb)
+ self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb)
+
+ if exc_type is None:
+ # A dummy config so we can verify the begin and end offsets.
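+ # `end=None`: the final end offset is only known once the nested reader configs have been assembled, so `_get_config` fills it in.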
+ config = self._get_config(self._begin, None) + _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) + + if self._has_loss_masking_spans: + _copy_chunked( + self._path.joinpath("loss_masking_spans"), + self._stream, + config.loss_masking_spans.begin, + config.loss_masking_spans.end, + ) + if self._has_preference_spans: + _copy_chunked( + self._path.joinpath("chosen_spans"), + self._stream, + config.chosen_spans.begin, + config.chosen_spans.end, + ) + _copy_chunked( + self._path.joinpath("rejected_spans"), + self._stream, + config.rejected_spans.begin, + config.rejected_spans.end, + ) + + self._directory.cleanup() + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[LanguageModelReaderConfig]: + return LanguageModelReaderConfig + + def _get_config(self, begin: int, end: int | None): + tokens = self._token_writer.get_config(begin + len(LanguageModelReaderConfig.header)) + offset = tokens.end + if self._has_loss_masking_spans: + loss_masking_spans = self._loss_masking_span_writer.get_config(offset) + offset = loss_masking_spans.end + else: + loss_masking_spans = NullReaderConfig() + if self._has_preference_spans: + chosen_spans = self._chosen_spans_writer.get_config(offset) + offset = chosen_spans.end + rejected_spans = self._rejected_spans_writer.get_config(offset) + offset = rejected_spans.end + else: + chosen_spans = NullReaderConfig() + rejected_spans = NullReaderConfig() + + if end is None: + end = offset + len(LanguageModelReaderConfig.footer) + + return LanguageModelReaderConfig( + begin=begin, + end=end, + tokens=tokens, + loss_masking_spans=loss_masking_spans, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + ) + + +def _copy_chunked(path: pathlib.Path, stream: io.BufferedWriter, expected_begin: int, expected_end: int): + # Copy temporary file content in chunks of 100 MB. 
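+ # `expected_begin` and `expected_end` come from the reader configs, so any mismatch between the declared and written sizes fails here.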
+ Assert.eq(stream.tell(), expected_begin) + with path.open("rb") as input_stream: + while data := input_stream.read(100000000): + stream.write(data) + Assert.eq(stream.tell(), expected_end) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index d121a38b6..c3a035376 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -1,7 +1,18 @@ import typing -from fast_llm.data.sample.abstract import Batch, Sample -from fast_llm.utils import get_unique +import numpy as np +import torch + +from fast_llm.config import Field, config_class +from fast_llm.data.sample.abstract import ( + Batch, + MemmapReader, + MemmapReaderBaseConfig, + MemmapReaderConfig, + MemmapWriter, + Sample, +) +from fast_llm.utils import Assert, get_unique class RangeSample(Sample): @@ -47,3 +58,78 @@ def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self: def to_samples(self) -> list[RangeSample]: return [RangeSample(sample_ranges, self.sample_size) for sample_ranges in self.ranges] + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "range"}) +class RangeReaderConfig(MemmapReaderConfig): + _abstract = False + header: typing.ClassVar[bytes] = b"range begin" + footer: typing.ClassVar[bytes] = b"range end" + num_documents: int = Field() + num_ranges: int = Field() + + @property + def reader_class(self) -> "type[RangeReader]": + return RangeReader + + @property + def writer_class(self) -> "type[RangeWriter]": + return RangeWriter + + @property + def _expected_buffer_size(self) -> int: + return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize + + +class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + self._ranges = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_ranges * 2, + ).view(-1, 2) + self._count_cumsums = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_documents + 1, + offset=self._ranges.nbytes, + ) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + sample_size = end - begin + cropped_ranges = ( + (max(begin_ - begin, 0), min(end_ - begin, sample_size)) + for begin_, end_ in self._ranges[self._count_cumsums[index] : self._count_cumsums[index + 1]].tolist() + ) + return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) + + +class RangeWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._count_cumsum = [0] + return self + + def write(self, document: RangeSample): + super().write(document) + self._stream.write(np.array(document.ranges, dtype=np.int32).tobytes(order="C")) + self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges)) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + Assert.lt(self._count_cumsum[-1], np.iinfo(np.int32).max) + self._stream.write(np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[RangeReaderConfig]: + return RangeReaderConfig + + def _get_config(self, begin: int, end: int): + return RangeReaderConfig( + begin=begin, + end=end, + num_documents=len(self._count_cumsum) - 1, + num_ranges=self._count_cumsum[-1], + ) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 62d1c0e67..706b5053a 100644 --- a/fast_llm/data/sample/token.py +++ 
b/fast_llm/data/sample/token.py @@ -1,8 +1,18 @@ import typing +import numpy as np import torch -from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.config import Field, config_class +from fast_llm.data.sample.abstract import ( + Batch, + MemmapIndexedDatasetReader, + MemmapReaderBaseConfig, + MemmapReaderConfig, + MemmapWriter, + Sample, +) +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -73,3 +83,87 @@ def crop(self, begin: int, end: int) -> typing.Self: def to_device_(self, device: "torch.device | str"): # Also standardize the dtype while we're here. self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "token"}) +class TokenReaderConfig(MemmapReaderConfig): + _abstract = False + header: typing.ClassVar[bytes] = b"token begin" + footer: typing.ClassVar[bytes] = b"token end" + num_documents: int = Field() + num_tokens: int = Field() + data_type: DataType = Field() + + def __len__(self) -> int: + return self.num_documents + + @property + def reader_class(self) -> "type[TokenReader]": + return TokenReader + + @property + def writer_class(self) -> "type[TokenWriter]": + return TokenWriter + + @property + def _expected_buffer_size(self) -> int: + return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.int64.itemsize + + +class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + self._tokens = torch.frombuffer( + self._buffer, + dtype=self._config.data_type.torch, + count=self._config.num_tokens, + ) + self._size_cumsums = torch.frombuffer( + self._buffer, dtype=torch.int64, count=self._config.num_documents + 1, offset=self._tokens.nbytes + ) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + begin_ = self._size_cumsums[index].item() + # Torch doesn't support type promotion between signed and unsigned types, so we convert here to avoid issues. 
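+ # `begin_` is the document's start offset in the flat token buffer; `begin` and `end` are offsets within the document.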
+ return TokenSample(self._tokens[begin_ + begin : begin_ + end].to(torch.int64), [end - begin]) + + def get_document_sizes(self) -> torch.Tensor: + return self._size_cumsums[1:] - self._size_cumsums[:-1] + + def get_document_size(self, index: int) -> int: + return self._size_cumsums[index + 1].item() - self._size_cumsums[index].item() + + +class TokenWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._size_cumsum = [0] + self._data_type = None + return self + + def write(self, document: TokenSample): + super().write(document) + if self._data_type is None: + self._data_type = document.tokens.dtype + else: + Assert.eq(self._data_type, document.tokens.dtype) + self._stream.write(document.tokens.numpy().tobytes()) + self._size_cumsum.append(self._size_cumsum[-1] + len(document.tokens)) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[TokenReaderConfig]: + return TokenReaderConfig + + def _get_config(self, begin: int, end: int): + return TokenReaderConfig( + begin=begin, + end=end, + num_documents=len(self._size_cumsum) - 1, + num_tokens=self._size_cumsum[-1], + data_type=DataType.from_torch(self._data_type), + ) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py deleted file mode 100644 index c74586207..000000000 --- a/fast_llm/data/tokenizer.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -import torch -from transformers import AutoTokenizer - -from fast_llm.data.config import TokenizerConfig -from fast_llm.engine.config_utils.run import log_main_rank - - -class Tokenizer: - """ - A wrapper around Huggingface (transformers) tokenizer. - """ - - def __init__(self, config: TokenizerConfig): - log_main_rank(f"> loading tokenizer from {config.path} ...") - self.tokenizer = AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=config.path, - errors="replace", - max_len=None, - trust_remote_code=True, - use_fast=True, - ) - if config.bos_token is not None: - self.tokenizer.bos_token = config.bos_token - if self.tokenizer.eos_token_id is None: - raise ValueError("Tokenizer does not have an EOS token.") - if self.tokenizer.bos_token_id is None: - raise ValueError("Tokenizer does not have an BOS token.") - self.eod_id = self.tokenizer.eos_token_id - self.bod_id = self.tokenizer.bos_token_id - - @property - def vocab_size(self) -> int: - return len(self.tokenizer) - - @property - def vocab(self) -> dict[str, int]: - return self.tokenizer.vocab - - @property - def inv_vocab(self) -> dict[int, str]: - return self._inv_vocab - - def tokenize(self, text: str, begin=True, end=True) -> list[int]: - return ( - ([self.bod_id] if begin else []) - + self.tokenizer.encode(text, add_special_tokens=False) - + ([self.eod_id] if end else []) - ) - - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: - """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
- """ - input_ids = [] - token_spans = [] - char_pos = 0 - beginning_of_text = True - - for start, end in char_spans: - if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 - if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans - - def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: - return self.tokenizer.decode(token_ids) - - @property - def eod(self): - return self.eod_id diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index 1a0fed91b..27709a8bb 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -50,9 +50,13 @@ def from_torch(cls, dtype: "torch.dtype") -> "DataType": return _TORCH_DTYPE_MAP_INV[dtype] @classmethod - def from_numpy(cls, dtype: "np.dtype") -> "DataType": + def from_numpy(cls, dtype: "np.dtype | type[np.number]") -> "DataType": + import numpy as np + if not _NUMPY_DTYPE_MAP_INV: _set_numpy_dtype_map() + if isinstance(dtype, np.dtype): + dtype = dtype.type return _NUMPY_DTYPE_MAP_INV[dtype] @classmethod diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 051163084..163a9459c 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -106,7 +106,7 @@ def _get_runnable(self) -> typing.Callable[[], None]: return self.run def run(self) -> None: - raise NotImplementedError() + self._get_runnable()() def _show[ T diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index f8dfd4825..df7ab0f51 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -2,7 +2,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.data.config import TokenizerConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.utils import Assert diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 7ab0b9ff6..c5ae48eba 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -37,7 +37,7 @@ def compute_dpo_loss( reference_log_probabilities, chosen_spans ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) - # TODO: ====== Shouldn't the sigmoid be computed independently for each document? + # TODO: ====== Shouldn't the sigmoid be computed independently for each document? 
======= losses = -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)) if grad_output is None: diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 322932664..414314627 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,6 +2,7 @@ Import these submodules to ensure classes are added to the dynamic class registry. """ +from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 1f9feceb4..83675ac74 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -161,21 +161,22 @@ def rms_close_relative(x, y, threshold, min_threshold=0): assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" @staticmethod - def all_equal(x, y): + def all_equal(x, *args): import torch # Make it work for lists and numpy arrays. x = torch.as_tensor(x) - y = torch.as_tensor(y) - - Assert.eq(x.shape, y.shape) - neq = x != y - if neq.any().item(): # noqa - index = None if x.numel() == 1 else torch.where(neq) # noqa - raise AssertionError( - f"Tensors have {index[0].numel()} different entries out of " - f"{x.numel()}: {x[index]} != {y[index]} at index {torch.stack(index, -1)}" - ) + for arg in args: + arg = torch.as_tensor(arg) + + Assert.eq(x.shape, arg.shape) + neq = x != arg + if neq.any().item(): # noqa + index = None if x.numel() == 1 else torch.where(neq) # noqa + raise AssertionError( + f"Tensors have {index[0].numel()} different entries out of " + f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + ) @staticmethod def all_different(x, y): diff --git a/tests/data/common.py b/tests/data/common.py index e6ab8a265..ac8d8023c 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -4,26 +4,18 @@ import numpy as np import torch -from fast_llm.config import Field, FieldHint, NoAutoValidate, config_class +from fast_llm.config import NoAutoValidate from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import ( - IndexedDatasetConfig, - SampledDatasetConfig, - SamplingConfig, - SamplingParameters, - ShufflingType, -) +from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfig, ShufflingType from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.dataset.sampled import SampledIndexedDataset -from fast_llm.data.sample.abstract import Sample from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert, div -from tests.utils.global_variables import TEST_VOCAB_SIZE def get_sampling_data( @@ -33,7 +25,7 @@ def get_sampling_data( cache_directory: pathlib.Path | None = None, phase=PhaseType.training, sequence_length: int = 512, - vocab_size=TEST_VOCAB_SIZE, + vocab_size: int | None = None, gpu: bool = False, shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, @@ -73,7 +65,7 @@ def 
get_test_data_and_compare_samples( shuffle: ShufflingType = ShufflingType.epoch, cache_directory: pathlib.Path | None = None, sequence_length: int = 512, - vocab_size=TEST_VOCAB_SIZE, + vocab_size: int | None = None, expected_samples: dict[str, list[list[int]]] | list[list[int]], ) -> GPTData: distributed_config = DistributedConfig(seed=87522) @@ -111,45 +103,30 @@ def get_test_data_and_compare_samples( for phase, samples in samples_per_dataset.items() } for phase, expected_samples_ in expected_samples.items(): - Assert.all_equal(tokens[phase].to(torch.int64), expected_samples_) + Assert.all_equal(tokens[phase], expected_samples_) return data -def compare_indexed_dataset( +def compare_indexed_dataset_tokens( dataset: IndexedDataset, length: int, num_tokens: int, expected_samples: dict[int, list[int]], - loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) sizes = dataset.get_document_sizes() - # Assert.eq(sizes.sum(), num_tokens) + Assert.eq(sizes.sum(), num_tokens, dataset.num_tokens) Assert.all_equal( [len(dataset.get_document(i).tokens.tokens) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): - Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample, dtype=np.int64)) - if loss_masking_spans: - for i, loss_masking_span in loss_masking_spans.items(): - print(i) - Assert.eq( - dataset.get_document( - i, - parameters=GPTSamplingParameters( - num_samples=0, sequence_length=0, vocab_size=0, use_loss_masking_spans=True - ), - ).loss_masking_spans.ranges, - loss_masking_spans[i], - ) + Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample)) def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: Assert.eq(len(sampled), len(expected_samples)) - Assert.all_equal( - torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]).to(torch.int64), expected_samples - ) + Assert.all_equal(torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]), expected_samples) def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): @@ -185,54 +162,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s for index in range(sampled._parameters.num_samples) ] token_ids = torch.stack([sampled[i].tokens.tokens for i in range(len(sampled))]).to(torch.int64) - Assert.all_equal(token_ids, validate_samples) if expected_samples is not None: Assert.all_equal(token_ids, expected_samples) return token_ids - - -@config_class(dynamic_type={SampledDatasetConfig: "mock_memmap"}) -class MockGPTMemmapDatasetConfig(IndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - num_documents: int | None = Field( - default=None, - desc="Expected number of documents in the dataset.", - hint=FieldHint.core, - ) - num_tokens_per_document: int | None = Field( - default=None, - desc="Expected number of tokens in the dataset.", - hint=FieldHint.optional, - ) - path: pathlib.Path = Field(default=".") - - def build(self) -> "IndexedDataset": - return MockMemmapDataset(self) - - @property - def num_tokens(self) -> int: - return self.num_documents * self.num_tokens_per_document - - -class MockMemmapDataset[SampleType: Sample](IndexedDataset[SampleType]): - def __init__(self, config: MockGPTMemmapDatasetConfig): - self._config = config - - @property - def name(self) -> str: - return 
"mock_memmap" - - def __len__(self) -> int: - return self._config.num_documents - - def get_document_sizes(self) -> torch.Tensor: - return torch.full([self._config.num_documents], self._config.num_tokens_per_document, dtype=torch.int64) - - def get_document_size(self, index: int) -> int: - return self._config.num_tokens_per_document - - def get_document( - self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None - ) -> SampleType: - raise NotImplementedError() diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 0099cb50b..88ecf2c99 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -12,17 +12,11 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX - -_DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" - - -def _get_test_dataset_mix_1(): - return get_test_dataset(prefix=_DATASET_PREFIX_MIX_1, seed=2345) +from tests.utils.dataset import get_alt_test_dataset, get_common_test_dataset def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, np.ndarray]: + # Alternate implementation for blending. probs = np.array(probs) dataset_index = np.zeros(num_samples) sample_index = np.zeros(num_samples) @@ -37,25 +31,25 @@ def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, GPT_BLENDED_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [4628, 7392, 920, 79, 1322, 387], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [387, 4224, 87, 2713, 423, 324], - [3036, 253, 207, 2968, 4536, 1178], + [49152, 46, 10, 819, 19, 45], + [45, 69, 17, 86, 38826, 15], + [49152, 83, 80, 20452, 45, 93], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [93, 90, 39, 6, 75, 9], + [58, 22885, 93, 37, 92, 76], ] GPT_BLENDED_MIXED_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], + [49152, 46, 10, 819, 19, 45], [916, 6683, 7685, 1277, 5106, 378], - [1790, 80, 6506, 1735, 542, 88], + [45, 69, 17, 86, 38826, 15], [3359, 6803, 780, 4561, 669, 7878], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], [6920, 2218, 2921, 3963, 7606, 6904], - [207, 4700, 549, 79, 417, 3036], + [1793, 1, 1746, 38, 27, 58], ] @@ -112,38 +106,21 @@ def test_blending(probs): def test_gpt_blended(): # Make sure dataset blending works and check for unintended changes in behavior. - get_test_dataset() - _get_test_dataset_mix_1() + _, config, _ = get_common_test_dataset() + _, alt_config, _ = get_alt_test_dataset() sampled = get_dataset_config( - { + dataset_config := { "type": "blended", - "datasets": [ - {"type": "memmap", "path": DATASET_PREFIX}, - {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, - ], + "datasets": [config, alt_config], "weights": [0.75, 0.25], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) - -def test_gpt_blended_data(): - get_test_dataset() - _get_test_dataset_mix_1() + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "blended", - "datasets": [ - {"type": "memmap", "path": DATASET_PREFIX}, - {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, - ], - "weights": [0.75, 0.25], - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_BLENDED_SAMPLES, @@ -152,34 +129,25 @@ def test_gpt_blended_data(): def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() sampled = get_dataset_config( - { + dataset_config := { "type": "blended", "datasets": [ - {"type": "memmap", "path": DATASET_PREFIX}, + config, {"type": "random"}, ], "weights": [0.6, 0.4], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) - -def test_gpt_blended_mixed_data(): - get_test_dataset() + # Test in data. get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "blended", - "datasets": [{"type": "memmap", "path": DATASET_PREFIX}, {"type": "random"}], - "weights": [0.6, 0.4], - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, + vocab_size=8192, expected_samples=GPT_BLENDED_MIXED_SAMPLES, ) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 5335e01c0..d7e750c8b 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,56 +1,48 @@ from fast_llm.data.dataset.config import ConcatenatedDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( - compare_indexed_dataset, + compare_indexed_dataset_tokens, compare_sampled_dataset, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, ) -from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_SAMPLES, COMMON_DATASET_TOKENS +from tests.utils.dataset import get_common_test_dataset GPT_CONCATENATED_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [3036, 253, 207, 2968, 4536, 1178], - [1178, 3291, 317, 277, 2679, 89], - [89, 542, 395, 583, 684, 554], + [49152, 46, 10, 819, 19, 45], + [45, 69, 17, 86, 38826, 15], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [58, 22885, 93, 37, 92, 76], + [76, 29, 19, 17365, 93, 46], + [46, 83, 17211, 1, 785, 1023], ] def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. 
- get_test_dataset() + _, config, _ = get_common_test_dataset() + memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() dataset = get_dataset_config( - {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]}, + dataset_config := {"type": "concatenated", "datasets": [memmap_config.to_dict() for _ in range(3)]}, ConcatenatedDatasetConfig[LanguageModelSample], ).build() - compare_indexed_dataset( + compare_indexed_dataset_tokens( dataset, - 3 * MEMMAP_DATASET_LENGTH, - 3 * MEMMAP_DATASET_TOKENS, - {j * MEMMAP_DATASET_LENGTH + i: sample for j in range(3) for i, sample in MEMMAP_DATASET_SAMPLES.items()}, + 3 * COMMON_DATASET_LENGTH, + 3 * COMMON_DATASET_TOKENS, + {j * COMMON_DATASET_LENGTH + i: sample for j in range(3) for i, sample in COMMON_DATASET_SAMPLES.items()}, ) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_CONCATENATED_SAMPLES) - -def test_gpt_concatenate_data(): - get_test_dataset() + # Test in data. get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "concatenated", - "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)], - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_CONCATENATED_SAMPLES, diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py deleted file mode 100644 index c149e1395..000000000 --- a/tests/data/test_dataset_from_file.py +++ /dev/null @@ -1,12 +0,0 @@ -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from tests.data.common import compare_indexed_dataset, get_dataset_config -from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX - - -def test_dataset_from_file(): - get_test_dataset() - dataset_config = {"type": "file", "path": str(DATASET_PREFIX.parent.joinpath("fast_llm_config.yaml"))} - dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() - compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 438c5e7e3..0600c5258 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -5,34 +5,30 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX, TOKENIZER_PATH +from tests.utils.dataset import get_common_test_dataset +from tests.utils.global_variables import TOKENIZER_PATH GPT_FIM_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [86, 89, 7876, 80, 49152, 87], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [86, 89, 1178, 49152, 87, 49152], - [86, 49152, 1178, 64, 89, 900], - [86, 49152, 89, 542, 395, 89], + [46, 10, 819, 19, 45, 88], + [45, 69, 17, 86, 38826, 15], + [86, 89, 32348, 64, 49152, 87], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [86, 89, 37, 92, 76, 49152], + [86, 49152, 76, 29, 19, 89], + [86, 49152, 46, 83, 17211, 1], ] def test_gpt_fim(): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. 
- sampling_config = get_sampling_data( - 8, - sequence_length=5, - vocab_size=49157, - ) + sampling_config = get_sampling_data(8, sequence_length=5) sampled = get_dataset_config( - { + dataset_config := { "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": config, "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", @@ -44,26 +40,9 @@ def test_gpt_fim(): ).build_and_sample(sampling_config) compare_sampled_dataset(sampled, GPT_FIM_SAMPLES) - -def test_gpt_fim_data(): - get_test_dataset() get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, - "tokenizer": {"path": TOKENIZER_PATH}, - "rate": 0.5, - "prefix_token": "w", - "middle_token": "x", - "pad_token": "y", - "suffix_token": "z", - } - }, - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_FIM_SAMPLES, - vocab_size=49157, ) diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py new file mode 100644 index 000000000..443a26819 --- /dev/null +++ b/tests/data/test_loss_masking_spans.py @@ -0,0 +1,80 @@ +import datasets +import pytest + +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.utils import Assert +from tests.data.common import get_dataset_config +from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_TEXT +from tests.utils.dataset import get_test_dataset_with_loss_masking_spans +from tests.utils.global_variables import TOKENIZER_NAME + +DATASET_WITH_SPAN_TOKENS = 45577 +DATASET_WITH_SPAN_SAMPLES = { + 27: [49152, 63, 82, 11, 27799, 49152], + 30: [49152, 31, 85, 78, 27, 1448, 62, 43, 49152], + 31: [49152, 60, 55, 80, 30, 85, 22, 18, 49152], + 77: [49152, 73, 80, 85, 52, 22, 46, 5, 88, 78, 49152], + 87: [49152, 52, 42536, 11, 71, 49152], +} +HF_LOSS_MASKING_SPANS = { + 27: [[0, 1]], + 30: [[0, 1]], + 31: [[0, 0], [2, 2], [5, 5]], + 77: [[0, 0], [2, 2], [5, 5], [7, 7]], + 87: [[0, 0], [3, 3]], +} +TOKEN_LOSS_MASKING_SPANS = { + 27: [(1, 3)], + 30: [(1, 3)], + 31: [(1, 2), (3, 4), (6, 7)], + 77: [(1, 2), (3, 4), (6, 7), (8, 9)], + 87: [(1, 2), (3, 4)], +} + + +@pytest.mark.slow +def test_gpt_data_with_spans(): + _, config, hf_path = get_test_dataset_with_loss_masking_spans() + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + + hf_dataset = datasets.load_from_disk(hf_path)["train"] + tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() + + # Check global stats. + Assert.eq(len(dataset), len(hf_dataset), COMMON_DATASET_LENGTH) + Assert.eq(dataset.num_tokens, DATASET_WITH_SPAN_TOKENS) + + for index in range(0, 200, 8): + expected_text = hf_dataset[index]["text"] + expected_text_spans = [(begin, last + 1) for begin, last in hf_dataset[index]["loss_masking_spans"]] + expected_tokens, expected_spans = tokenizer.tokenize_with_spans( + hf_dataset[index]["text"], + text_spans=[(begin, last + 1) for begin, last in hf_dataset[index]["loss_masking_spans"]], + ) + document = dataset.get_document( + index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) + ) + + # Compare tokens and token spans. 
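+ # Token spans are (begin, end) pairs of token indices with an exclusive end.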
+ Assert.all_equal(document.tokens.tokens, expected_tokens) + Assert.eq(document.loss_masking_spans.ranges, expected_spans) + + # Compare text. + text, text_spans = tokenizer.detokenize_with_spans( + document.tokens.tokens, True, True, token_spans=document.loss_masking_spans.ranges + ) + Assert.eq(text, expected_text) + Assert.eq(text_spans, expected_text_spans) + + # Check some numerical values. + for index in DATASET_WITH_SPAN_SAMPLES: + Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) + Assert.eq(hf_dataset[index]["loss_masking_spans"], HF_LOSS_MASKING_SPANS[index]) + document = dataset.get_document( + index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) + ) + Assert.eq(document.tokens.tokens.tolist(), DATASET_WITH_SPAN_SAMPLES[index]) + Assert.eq(document.loss_masking_spans.ranges, TOKEN_LOSS_MASKING_SPANS[index]) diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py deleted file mode 100644 index ca887f3c1..000000000 --- a/tests/data/test_memmap.py +++ /dev/null @@ -1,49 +0,0 @@ -import pathlib - -import pytest - -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from tests.data.common import compare_indexed_dataset, get_dataset_config -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE - -MEMMAP_DATASET_LENGTH = 6153 -MEMMAP_DATASET_TOKENS = 508327 -MEMMAP_DATASET_SAMPLES = { - 9: [], - 10: [80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], - 13: [78, 727, 74, 317, 1358, 89], - 15: [78], -} - - -@pytest.mark.parametrize("cache_directory", (None, pathlib.Path(DATASET_SAMPLING_CACHE) / "test_memmap")) -def test_gpt_memmap(cache_directory): - # Make sure the memmap dataset works and check for unintended changes in behavior. 
- get_test_dataset() - dataset = get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build() - compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) - - -MEMMAP_DATASET_SPANS = { - 9: [], - 10: [(0, 2), (2, 7), (7, 10)], - 13: [(0, 2)], - 15: [], -} - -_DATASET_PREFIX_SPANS = DATASET_CACHE / "with_spans" / "dataset" - - -def test_gpt_data_with_spans(): - get_test_dataset(prefix=_DATASET_PREFIX_SPANS, max_spans=5) - dataset = get_dataset_config( - { - "type": "memmap", - "path": _DATASET_PREFIX_SPANS, - }, - GPTMemmapDatasetConfig, - ).build() - compare_indexed_dataset( - dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_SPANS - ) diff --git a/tests/data/test_preference_spans.py b/tests/data/test_preference_spans.py new file mode 100644 index 000000000..ef18337eb --- /dev/null +++ b/tests/data/test_preference_spans.py @@ -0,0 +1,107 @@ +import datasets +import numpy as np +import pytest +import torch + +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.utils import Assert +from tests.data.common import get_dataset_config +from tests.data.test_preparator import COMMON_DATASET_LENGTH +from tests.utils.dataset import get_test_dataset_with_preference_spans +from tests.utils.global_variables import TOKENIZER_NAME + +DATASET_WITH_PREFERENCE_SPAN_TOKENS = 62163 +DATASET_WITH_PREFERENCE_SPAN_TEXT = { + 27: ["`", "s,", "uh"], + 30: ["@v", "o{hf_dataset[index]["answer"]}<|endoftext|>", + ) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py deleted file mode 100644 index 601abcf99..000000000 --- a/tests/data/test_prepare_gpt_memmap.py +++ /dev/null @@ -1,201 +0,0 @@ -import json -import pathlib -import tempfile - -import numpy as np -import pytest -import torch - -from fast_llm.data.dataset.config import IndexedDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTSamplingParameters -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig -from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.utils import Assert -from tests.data.common import MockGPTMemmapDatasetConfig # Noqa - - -def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator: - config = GPTMemmapDatasetPreparatorConfig.from_dict( - { - "output_path": output_path, - "dataset": {"path": dataset_path_name}, - "tokenizer": {"path": "no_tokenizer"}, - }, - {}, - ) - return config.get_dataset_preparator_class()(config=config) - - -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_dataset(dtype): - documents = [ - (torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)), None, None, None) - for _ in range(100) - ] - with tempfile.TemporaryDirectory() as temp_dir: - prefix = pathlib.Path(temp_dir) - GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) - dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, (tokens, _, _, _) in enumerate(documents): - 
Assert.all_equal(dataset.get_document(i).tokens.tokens, tokens.to(torch.int64)) - - -def _generate_valid_span(max_seq_length): - return np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist() - - -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_preference_dataset(dtype): - documents = [ - ( - torch.from_numpy(np.random.randint(1000, size=100).astype(dtype)), - None, - _generate_valid_span(100), - _generate_valid_span(100), - ) - for _ in range(50) - ] - with tempfile.TemporaryDirectory() as temp_dir: - prefix = pathlib.Path(temp_dir) - GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) - dataset = GPTMemmapDataset(name="foo", prefix=prefix) - parameters = GPTSamplingParameters( - num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True - ) - for i, (token_ids, _, (chosen_begin, chosen_end), (rejected_begin, rejected_end)) in enumerate(documents): - document = dataset.get_document(i, parameters=parameters) - Assert.all_equal(document.tokens.tokens, token_ids.to(torch.int64)) - Assert.eq(document.chosen_spans.ranges, [(chosen_begin, chosen_end + 1)]) - Assert.eq(document.rejected_spans.ranges, [(rejected_begin, rejected_end + 1)]) - - -def test_load_metadata_from_hub(): - with tempfile.TemporaryDirectory(suffix="test") as local_folder: - get_preparator(local_folder, "lhoestq/demo1")._save_croissant_metadata() - croissant_path = pathlib.Path(local_folder) / "croissant.json" - assert croissant_path.is_file() - metadata = json.load(croissant_path.open("r")) - assert metadata["url"] == "https://huggingface.co/datasets/lhoestq/demo1" - - -def test_absent_metadata_from_hub(): - with tempfile.TemporaryDirectory(suffix="test") as local_folder: - get_preparator(local_folder, "allenai/dolma")._save_croissant_metadata() - assert not (pathlib.Path(local_folder) / "croissant.json").is_file() - - -def test_load_metadata_local(): - with ( - tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder, - tempfile.TemporaryDirectory(suffix="test") as local_folder, - ): - metadata = {"name": "test"} - json.dump(metadata, (pathlib.Path(dataset_folder) / "croissant.json").open("w")) - get_preparator(local_folder, dataset_folder)._save_croissant_metadata() - croissant_path = pathlib.Path(local_folder) / "croissant.json" - assert croissant_path.is_file() - assert json.loads(croissant_path.open("r").read()) == metadata - - -def test_absent_metadata_local(): - with ( - tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder, - tempfile.TemporaryDirectory(suffix="test") as local_folder, - ): - get_preparator(local_folder, dataset_folder)._save_croissant_metadata() - assert not (pathlib.Path(local_folder) / "croissant.json").is_file() - - -DATASET_DICT_0 = { - "type": "mock_memmap", - "num_documents": 500, - "num_tokens_per_document": 300, -} -DATASET_DICT_1 = { - "type": "mock_memmap", - "num_documents": 1500, - "num_tokens_per_document": 100, -} - - -def test_split_dataset(): - dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) - config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0], - {"training": 3, "validation": 1}, - pathlib.Path("."), - ) - config = {key: value.to_dict() for key, value in config.items()} - - Assert.eq( - config, - { - "training": { - "type": "slice", - "dataset": dataset_config_0.to_dict(), - "begin": 0, - "end": 0.75, - }, - "validation": { - "type": "slice", - "dataset": 
dataset_config_0.to_dict(), - "begin": 0.75, - "end": 1, - }, - }, - ) - - -def test_split_datasets_0(): - dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) - config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0, dataset_config_1], - {"training": 1, "validation": 1}, - pathlib.Path("."), - ) - config = {key: value.to_dict() for key, value in config.items()} - - Assert.eq( - config, - { - "training": dataset_config_0.to_dict(), - "validation": dataset_config_1.to_dict(), - }, - ) - - -def test_split_datasets_1(): - dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) - config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") - ) - config = {key: value.to_dict() for key, value in config.items()} - - Assert.eq( - config, - { - "training": { - "type": "blended", - "datasets": [ - dataset_config_0.to_dict(), - { - "type": "slice", - "dataset": dataset_config_1.to_dict(), - "begin": 0, - "end": 0.5, - }, - ], - "weights": [2 / 3, 1 / 3], - }, - "validation": { - "type": "slice", - "dataset": dataset_config_1.to_dict(), - "begin": 0.5, - "end": 1, - }, - }, - ) diff --git a/tests/data/test_random.py b/tests/data/test_random.py index 8e5c61904..7a31358b9 100644 --- a/tests/data/test_random.py +++ b/tests/data/test_random.py @@ -16,22 +16,16 @@ def test_gpt_random_dataset(): # Make sure the random dataset works and check for unintended changes in behavior. - sampled = get_dataset_config({"type": "random"}, GPTRandomDatasetConfig).build_and_sample( - get_sampling_data(4, sequence_length=7) + sampled = get_dataset_config(config := {"type": "random"}, GPTRandomDatasetConfig).build_and_sample( + get_sampling_data(4, sequence_length=7, vocab_size=8192) ) compare_sampled_dataset(sampled, RANDOM_DATASET_EXPECTED_SAMPLES) - -def test_gpt_random_data(): + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "random", - } - } - }, + {"datasets": {"training": config}}, 4, sequence_length=7, + vocab_size=8192, expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 58f4d3dab..2d102be01 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -3,7 +3,7 @@ import torch from fast_llm.data.dataset.config import ShufflingType -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -14,8 +14,7 @@ get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.dataset import get_common_test_dataset try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa @@ -26,37 +25,28 @@ GPT_MEMMAP_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [3036, 253, 207, 2968, 4536, 1178], - [1178, 3291, 317, 277, 2679, 89], - [89, 542, 395, 583, 684, 554], + [49152, 46, 10, 819, 19, 45], + [45, 69, 17, 86, 38826, 15], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [58, 22885, 93, 37, 92, 76], + [76, 29, 19, 17365, 93, 46], + [46, 83, 17211, 1, 785, 1023], ] def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. - get_test_dataset() - sampled = get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build_and_sample( - get_sampling_data(8, sequence_length=5) - ) + _, config, _ = get_common_test_dataset() + sampled = get_dataset_config( + dataset_config := config, GPTDatasetFromFileConfig[LanguageModelSample] + ).build_and_sample(get_sampling_data(8, sequence_length=5)) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) - -def test_gpt_sampled_data(): - get_test_dataset() + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "memmap", - "path": DATASET_PREFIX, - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_MEMMAP_SAMPLES, @@ -169,7 +159,6 @@ def test_gpt_sample_padding(): sampling = get_sampling_data( num_samples=len(expected_samples), sequence_length=sequence_length, - vocab_size=vocab_size, seed=seed, shuffle=ShufflingType.disabled, truncate_documents=False, diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 3c6ae10d4..224b18270 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,67 +1,67 @@ from fast_llm.data.dataset.config import DatasetSliceConfig +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( - compare_indexed_dataset, + compare_indexed_dataset_tokens, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) -from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.data.test_preparator import COMMON_DATASET_SAMPLES +from tests.utils.dataset import get_common_test_dataset GPT_SLICE_TRAINING_SAMPLES = [ - [80, 268, 79, 260, 207, 3086], - [3086, 80, 413, 4872, 4602, 207], - [207, 7208, 1489, 776, 3514, 269], - [269, 73, 7367, 267, 477, 3126], + [49152, 20, 59, 81, 15, 54], + [54, 76, 7909, 44, 41, 1], + [1, 71, 28, 10, 42, 15963], + [15963, 80, 59, 86, 4, 74], ] GPT_SLICE_VALIDATION_SAMPLES = [ - [1886, 317, 5621, 3173, 330, 284], - [284, 2846, 706, 89, 80, 2047], - [2047, 207, 2449, 1423, 65, 985], - [985, 683, 4917, 87, 477, 481], - [481, 695, 947, 5871, 2344, 87], - [87, 489, 207, 489, 269, 356], - [356, 727, 7800, 4078, 243, 3712], - [3712, 86, 476, 80, 2547, 7390], + [49152, 3, 5621, 27, 7859, 13009], + [13009, 73, 32, 29, 32, 3], + [3, 89, 15, 45, 25, 75], + [75, 52, 13366, 88, 54, 19], + [19, 2, 74, 23, 92, 24747], + [24747, 42, 6, 477, 21, 47], + [47, 92, 31, 30, 463, 64], + [64, 23, 11, 56, 23555, 85], ] def test_gpt_slice(): # Make sure dataset splitting works and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() + memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() # samples[9:18] dataset = get_dataset_config( - {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003}, + {"type": "slice", "dataset": memmap_config, "begin": 0.025, "end": 0.1}, DatasetSliceConfig[LanguageModelSample], ).build() - compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) + compare_indexed_dataset_tokens(dataset, 75, 3399, {i - 25: sample for i, sample in COMMON_DATASET_SAMPLES.items()}) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) validate_indexed_dataset_sampling(sampled, GPT_SLICE_VALIDATION_SAMPLES) - -def test_gpt_slice_data(): + # Test in data with multiple phases. 
get_test_data_and_compare_samples( { "datasets": { "training": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": memmap_config, "begin": 0, - "end": 0.0015, + "end": 0.025, }, "validation": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, - "begin": 0.0015, - "end": 0.003, + "dataset": memmap_config, + "begin": 0.025, + "end": 0.1, }, "test": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, - "begin": 0.003, + "dataset": memmap_config, + "begin": 0.1, "end": 1, }, } diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 489f5e1c1..3a90745eb 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -61,12 +62,12 @@ def reference_dpo_loss( def test_dpo_loss(): - torch.manual_seed(0) - logits = torch.randn((10, 50, 100), requires_grad=True) - reference_model_logits = torch.randn((10, 50, 100)) - targets = torch.randint(0, 100, (10, 50)) + random_state = np.random.RandomState(0) + logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32).requires_grad_() + reference_model_logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32) + targets = torch.from_numpy(random_state.randint(0, 100, (10, 50))) - spans = get_random_spans(10, 10, 50) + spans = get_random_spans(10, 10, 50, random_state) fastllm_loss, fast_llm_grad = compute_dpo_loss( logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 7447e395a..f3ce65966 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -1,21 +1,29 @@ import os +import pathlib +import struct import typing +import datasets import numpy as np import pytest +import torch +import yaml from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import SampledDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingData -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.config import MemmapDatasetConfig, SampledDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER, LegacyMemmapDataset from fast_llm.data.dataset.sampled import logger +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.dataset import get_model_test_dataset +from tests.utils.dataset import get_common_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PREFIX +from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_NAME from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -26,6 +34,30 @@ except ImportError: _extension_available = False +MEGATRON_DATASET_PREFIX = DATASET_CACHE / "megatron_dataset/dataset" + + +def 
get_megatron_test_dataset(prefix: pathlib.Path = MEGATRON_DATASET_PREFIX): + if not ( + prefix.with_suffix(".idx").is_file() + and prefix.with_suffix(".bin").is_file() + and prefix.parent.joinpath("fast_llm_config.yaml").is_file() + ): + _, _, hf_path = get_common_test_dataset() + hf_dataset = datasets.load_from_disk(hf_path)["train"] + tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() + samples = [ + LanguageModelSample( + TokenSample((tokenizer.tokenize(document["text"]) % MODEL_TEST_VOCAB_SIZE).to(torch.uint16)) + ) + for document in hf_dataset + ] + + MegatronMemmapDataset.write_dataset(prefix, samples) + yaml.safe_dump( + {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") + ) + @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.megatron) @@ -35,11 +67,12 @@ def test_megatron(run_distributed_script, model_testing_config, run_test_script_ # Prevent Megatron from complaining. env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env["NVTE_FLASH_ATTN"] = "0" - get_model_test_dataset() + get_megatron_test_dataset() run_distributed_script( [ "Megatron-LM/pretrain_gpt.py", *model_testing_config.megatron_args, + f"--data-path={MEGATRON_DATASET_PREFIX}", f"--structured-logs-dir={path}", f"--data-cache-path={path}", ], @@ -69,7 +102,7 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co compare="megatron", config_args=[ "model.distributed.compute_dtype=fp32", - f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PREFIX}}}', + f'data.datasets.training={{"type":"megatron","path":{MEGATRON_DATASET_PREFIX}}}', "data.sampling.seed=1234", "model.base_model.use_megatron_initialization=True", ], @@ -82,25 +115,83 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co @config_class(dynamic_type={SampledDatasetConfig: "megatron"}) -class GPTMegatronDatasetConfig(GPTMemmapDatasetConfig): +class MegatronDatasetConfig[SampleType: LanguageModelSample](MemmapDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: str = Field( desc="Dataset path (prefix).", hint=FieldHint.core, ) - def build(self) -> "GPTMemmapDataset": - return GPTMegatronMemmapDataset( - str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens - ) - - -class GPTMegatronMemmapDataset(GPTMemmapDataset): - def sample(self, sampling: GPTSamplingData) -> "MegatronGPTSampledIndexedDataset": - return MegatronGPTSampledIndexedDataset(self, sampling) - - -class MegatronGPTSampledIndexedDataset(SampledDataset): + def build(self) -> "LegacyMemmapDataset[SampleType]": + return MegatronMemmapDataset(str(self.path).replace("/", "__"), self.path) + + +class MegatronMemmapDataset(LegacyMemmapDataset): + def sample(self, sampling: GPTSamplingData) -> "MegatronSampledIndexedDataset": + return MegatronSampledIndexedDataset(self, sampling) + + @classmethod + def write_dataset( + cls, + prefix: pathlib.Path | str, + documents: typing.Iterable[LanguageModelSample], + ) -> None: + # Initialize metadata + dtype = None + num_documents = 0 + lengths = [] + pointers = [] + offset = 0 + + prefix = pathlib.Path(prefix) + prefix.parent.mkdir(parents=True, exist_ok=True) + + # Write the binary data file (.bin) lazily + with prefix.with_suffix(".bin").open("wb") as bin_stream: + for document in documents: + token_ids = document.tokens.tokens + # Infer dtype from the first document + if dtype is None: + dtype = token_ids.dtype + assert dtype is not None, "Document dtype could not be inferred 
from the data."
+
+                # Ensure all documents have the same dtype
+                assert token_ids.dtype == dtype, f"Expected dtype {dtype}, got {token_ids.dtype}."
+
+                # Write document to binary file
+                bin_stream.write(token_ids.numpy().tobytes(order="C"))
+
+                # Update metadata
+                doc_length = len(token_ids)
+                lengths.append(doc_length)
+                pointers.append(offset)
+                offset += doc_length * dtype.itemsize
+                num_documents += 1
+
+        # Finalize metadata arrays
+        lengths = np.array(lengths, dtype=np.int32)
+        pointers = np.array(pointers, dtype=np.int64)
+
+        # Write the index file (.idx)
+        with prefix.with_suffix(".idx").open("wb") as idx_stream:
+            idx_stream.write(MEMMAP_INDEX_HEADER)
+            # Version
+            idx_stream.write(struct.pack("<Q", 1))
diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py
+def get_random_preference_spans(texts, random_state) -> dict[str, str]:
+    texts_ = []
+    chosen_spans = []
+    rejected_spans = []
+    for text in texts:
+        # Split into three non-empty chunks.
+        splits = np.sort(random_state.choice(range(1, len(text) - 1), 2, replace=False)).tolist()
+        texts_.append(text[: splits[0]])
+        chosen_spans.append(text[splits[0] : splits[1]])
+        rejected_spans.append(text[splits[1] :])
+    return {"text": texts_, "chosen_span": chosen_spans, "rejected_span": rejected_spans}
+
+
+def _get_hf_test_dataset(
     seed: int = 1234,
-    num_tokens: int = TEST_DATASET_TOKENS,
-    characters: str = TEST_CHARACTERS,
-    vocab_size: int = TEST_VOCAB_SIZE,
-    max_spans: int = 0,
+    num_documents: int = 1000,
+    min_document_size: int = 5,
+    max_document_size: int = 99,
+    min_loss_masking_spans: int = 0,
+    max_loss_masking_spans: int = 0,
+    has_preference_spans: bool = False,
 ):
-    download_santacoder_tokenizer()
+    random_state = np.random.RandomState(seed)
+    # Generate random documents and their sizes (character counts).
+    texts, document_sizes = get_random_text(num_documents, min_document_size, max_document_size, random_state)
 
-    if not (
-        prefix.with_suffix(".idx").is_file()
-        and prefix.with_suffix(".bin").is_file()
-        and prefix.parent.joinpath("fast_llm_config.yaml").is_file()
-    ):
-        import transformers
-
-        texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines()
-        tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
+    if has_preference_spans:
+        dataset_dict = get_random_preference_spans(texts, random_state)
+    else:
+        dataset_dict: dict[str, typing.Any] = {"text": texts}
 
-        samples = [
-            (
-                torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size),
-                None,
-                None,
-                None,
-            )
-            for document in texts
-        ]
-        if max_spans > 0:
-            spans = get_random_spans(
-                len(samples), max_spans, np.array([[max(len(tokens), 1)] for tokens, _, _, _ in samples]), seed
-            )
-            samples = [
-                (
-                    tokens,
-                    torch.tensor(sample_spans, dtype=torch.int32).reshape(-1, 2),
-                    None,
-                    None,
-                )
-                for (tokens, _, _, _), sample_spans in zip(samples, spans, strict=True)
-            ]
-
-        GPTMemmapDataset.write_dataset(prefix, samples)
-        yaml.safe_dump(
-            {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w")
+    if max_loss_masking_spans > 0:
+        dataset_dict["loss_masking_spans"] = get_random_spans(
+            document_sizes, min_loss_masking_spans, max_loss_masking_spans, random_state, use_last_format=True
         )
+    return datasets.Dataset.from_dict(dataset_dict)
 
-def get_model_test_dataset(
-    prefix: pathlib.Path = MODEL_DATASET_PREFIX,
-    vocab_size: int = MODEL_TEST_VOCAB_SIZE,
+
+def _get_test_dataset(
+    path: pathlib.Path,
+    seed: int,
+    tokenizer_path: str = TOKENIZER_PATH,
+    vocab_size: int | None = None,
+    documents_per_shard: int = 10**6,
+    num_documents: int = 1000,
+    min_document_size: int = 5,
+    
max_document_size: int = 99, + min_loss_masking_spans: int = 0, + max_loss_masking_spans: int = 0, + has_preference_spans: bool = False, + splits: dict[str, float] | None = None, ): - return get_test_dataset(prefix=prefix, vocab_size=vocab_size) + config_paths = ( + [path / "fast_llm_config.yaml"] + if splits is None + else [path / f"fast_llm_config_{split}.yaml" for split in splits] + ) + hf_path = path / "hf" + + if not all(config_path.is_file() for config_path in config_paths): + dataset = _get_hf_test_dataset( + seed=seed, + num_documents=num_documents, + min_document_size=min_document_size, + max_document_size=max_document_size, + min_loss_masking_spans=min_loss_masking_spans, + max_loss_masking_spans=max_loss_masking_spans, + has_preference_spans=has_preference_spans, + ) + datasets.DatasetDict({"train": dataset}).save_to_disk(hf_path) + source_schema = {"text": "text"} + if max_loss_masking_spans > 0: + source_schema["loss_masking_spans"] = "loss_masking_spans" + if has_preference_spans: + source_schema["chosen_span"] = "chosen_span" + source_schema["rejected_span"] = "rejected_span" + + download_santacoder_tokenizer() + preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( + { + "dataset": { + "path": hf_path, + "load_from_disk": True, + "source_schema": source_schema, + }, + "tokenizer": {"path": tokenizer_path, "max_vocab_size": vocab_size}, + "output_path": path, + "documents_per_shard": documents_per_shard, + "splits": splits, + } + ) + preparator_config.run() + + config = ( + {"type": "file", "path": config_paths[0]} + if splits is None + else { + split: {"type": "file", "path": config_path} + for split, config_path in zip(splits, config_paths, strict=True) + } + ) + return path, config, hf_path + + +def get_common_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "common_dataset", seed=1234) + + +def get_alt_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "other_dataset", seed=2345) + + +def get_sharded_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "common_dataset_sharded", seed=1234, documents_per_shard=350) + + +def get_split_test_dataset(): + return _get_test_dataset( + DATASET_CACHE / "common_dataset_split", seed=1234, splits={"training": 1, "validation": 1} + ) + + +def get_split_sharded_test_dataset(): + return _get_test_dataset( + DATASET_CACHE / "common_dataset_split_sharded", + seed=1234, + documents_per_shard=350, + splits={"training": 1, "validation": 1}, + ) + + +def get_test_dataset_with_loss_masking_spans(): + return _get_test_dataset(DATASET_CACHE / "dataset_with_loss_masking_spans", seed=1234, max_loss_masking_spans=5) + + +def get_test_dataset_with_preference_spans(): + return _get_test_dataset(DATASET_CACHE / "dataset_with_preference_spans", seed=1234, has_preference_spans=True) + + +def get_model_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "model_dataset", seed=1234, vocab_size=MODEL_TEST_VOCAB_SIZE) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 42e588911..20a0c7219 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -5,7 +5,6 @@ import os import pathlib -import string from fast_llm.utils import set_global_variables @@ -36,13 +35,11 @@ def set_testing_global_variables(): # TODO: Fixtures TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" +TOKENIZER_NAME = "bigcode/santacoder" + DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PREFIX = DATASET_CACHE / "common_dataset" 
-DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" -TEST_VOCAB_SIZE = 8192 -# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% -TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" -TEST_DATASET_TOKENS = 1000000 -MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset" +MODEL_DATASET_SHARD_PATH = DATASET_CACHE / "model_dataset/shard_0_0.fast_llm_dataset" + +DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index c02521d7b..956aaea5a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -22,7 +22,7 @@ Qwen2CheckpointFormat, ) from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE +from tests.utils.global_variables import MODEL_DATASET_SHARD_PATH, MODEL_TEST_VOCAB_SIZE from fast_llm.engine.evaluation.evaluators import ( # isort:skip # needed for dynamic type registration EvaluatorsConfig, @@ -234,18 +234,18 @@ def _update_and_add_testing_config( "data": { "datasets": { "training": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_SHARD_PATH}, "type": "slice", "end": 0.969, }, "validation": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_SHARD_PATH}, "type": "slice", "begin": 0.969, "end": 0.999, }, "test": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_SHARD_PATH}, "type": "slice", "begin": 0.999, "end": 1, @@ -279,7 +279,6 @@ def _update_and_add_testing_config( "--tokenizer-type=NullTokenizer", # Megatron messes with the vocab size, so we have to subtract 1. f"--vocab-size={MODEL_TEST_VOCAB_SIZE - 1}", - f"--data-path={MODEL_DATASET_PREFIX}", "--split=1,0,0", "--lr-decay-style=constant", # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py deleted file mode 100644 index bbfa4b21a..000000000 --- a/tools/concatenate_dataset.py +++ /dev/null @@ -1,60 +0,0 @@ -import json -import logging -import pathlib - -from fast_llm.config import Field, config_class -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.engine.config_utils.runnable import RunnableConfig - -logger = logging.getLogger(__name__) - - -@config_class() -class ConcatenateDatasetConfig(RunnableConfig): - directory: pathlib.Path = Field() - output_name: str = Field(default="fast_llm_dataset.json") - # A lower bound on the number of tokens in a dataset. - # Normally we would like each dataset split to contain at least a few samples, - # i.e. we want num_tokens >= sequence_length * min_split * min_samples_per_split. - # For example with a (999, 1, 0) split , 8K sequence length, we need at least 8M tokens - # for a single validation sample, possibly more if the split is imperfect. 
- min_tokens: int | None = Field(default=None) - - def run(self): - self.to_logs() - assert self.directory.is_dir() - output_file = self.directory / self.output_name - assert not output_file.exists(), str(output_file) - datasets = [] - - logger.info(f"Loading datasets from {self.directory}") - for path in self.directory.glob("**/*.idx"): - prefix = path.with_suffix("") - logger.info(str(prefix)) - dataset = GPTMemmapDataset("dataset", prefix) - dataset_dict = { - "prefix": str(prefix.relative_to(self.directory)), - "num_documents": len(dataset), - "num_tokens": dataset.num_tokens, - } - if self.min_tokens is not None and dataset_dict["num_tokens"] < self.min_tokens: - logger.info( - f"Ignoring dataset {dataset_dict['prefix']} with {dataset_dict['num_tokens']:,} tokens" - f" (requiring at least {self.min_tokens:,} tokens)" - ) - else: - datasets.append(dataset_dict) - total_documents = sum(dataset["num_documents"] for dataset in datasets) - total_tokens = sum(dataset["num_tokens"] for dataset in datasets) - logger.info(f"Found {total_documents:,} documents, {total_tokens:,} tokens in {len(datasets)} dataset files") - for dataset in datasets: - dataset["weight"] = dataset["num_tokens"] / total_tokens - logger.info( - f'{dataset["prefix"]}: documents = {dataset["num_documents"]:,}, tokens = {dataset["num_tokens"]:,}, weight = {dataset["weight"]:.6f}' - ) - logger.info(f"Saving merged dataset to {output_file}") - json.dump({"datasets": datasets}, output_file.open("w")) - - -if __name__ == "__main__": - ConcatenateDatasetConfig.parse_and_run()
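
For reference, here is a minimal sketch of the legacy Megatron-style `.idx` layout that `MegatronMemmapDataset.write_dataset` targets. It assumes the standard Megatron-LM indexed-dataset format (9-byte magic, little-endian u64 version, u8 dtype code, u64 sequence count, u64 document-index length, then int32 lengths, int64 byte pointers, and an int64 document index); the helper name `write_legacy_index` and the `DTYPE_CODES` table are illustrative only and not part of the Fast-LLM API.

import pathlib
import struct

import numpy as np

INDEX_HEADER = b"MMIDIDX\x00\x00"  # assumed magic; Fast-LLM exposes it as MEMMAP_INDEX_HEADER
DTYPE_CODES = {np.uint16: 8, np.int32: 4}  # assumed subset of the Megatron dtype code table


def write_legacy_index(prefix: pathlib.Path, lengths: np.ndarray, pointers: np.ndarray, dtype) -> None:
    # Write a `prefix.idx` file describing the token stream stored in `prefix.bin`.
    with prefix.with_suffix(".idx").open("wb") as idx_stream:
        idx_stream.write(INDEX_HEADER)
        idx_stream.write(struct.pack("<Q", 1))  # format version
        idx_stream.write(struct.pack("<B", DTYPE_CODES[np.dtype(dtype).type]))  # token dtype code
        idx_stream.write(struct.pack("<Q", len(lengths)))  # number of sequences
        idx_stream.write(struct.pack("<Q", len(lengths) + 1))  # number of document-index entries
        idx_stream.write(np.asarray(lengths, dtype=np.int32).tobytes(order="C"))  # tokens per document
        idx_stream.write(np.asarray(pointers, dtype=np.int64).tobytes(order="C"))  # byte offset of each document in the .bin file
        idx_stream.write(np.arange(len(lengths) + 1, dtype=np.int64).tobytes(order="C"))  # one document per sequence


# Example: an index for three uint16 documents of 5, 3 and 7 tokens.
# write_legacy_index(pathlib.Path("dataset"), np.array([5, 3, 7]), np.array([0, 10, 16]), np.uint16)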