69 changes: 58 additions & 11 deletions src/litdata/processing/data_processor.py
@@ -23,6 +23,7 @@
import sys
import tempfile
import traceback
import warnings
from abc import abstractmethod
from contextlib import suppress
from dataclasses import dataclass
@@ -299,29 +300,65 @@ def _upload_fn(
remove_queue.put([local_filepath])


def _map_items_to_workers_sequentially(num_workers: int, user_items: list[Any]) -> list[list[Any]]:
def _map_items_to_workers_sequentially(
num_workers: int, user_items: list[Any], chunk_size: Optional[int] = None
) -> list[list[Any]]:
"""Map the items to the workers sequentially.

Args:
num_workers: The number of workers to assign items to.
user_items: The list of items to be distributed among workers.
chunk_size: Optional chunk size that enforces deterministic,
single-worker-style chunk boundaries. When set, each worker is
assigned only full chunks of this size, and the final worker
receives any remaining items (which may form a partial chunk).


>>> workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
>>> assert workers_user_items == [[0, 1], [2, 3, 4]]
"""
assert isinstance(chunk_size, (int, type(None))), "chunk_size must be an integer or None"

num_nodes = _get_num_nodes()
node_rank = _get_node_rank()
world_size = num_nodes * num_workers
num_items_per_worker = len(user_items) // world_size

num_items_per_worker: list[int] = [num_items_per_worker for _ in range(world_size)]
reminder = len(user_items) % world_size
if chunk_size is not None:
assert chunk_size > 0, "chunk_size must be a positive integer"

for worker_idx in range(len(num_items_per_worker) - 1, -1, -1):
if reminder == 0:
break
num_items_per_worker[worker_idx] += 1
reminder -= 1
# Compute how many full chunks each worker can take
full_chunks = len(user_items) // chunk_size
chunks_per_worker = full_chunks // world_size

if chunks_per_worker == 0 and node_rank == 0:
warnings.warn(
f"chunk_size ({chunk_size}) is too large relative to dataset size ({len(user_items)}) "
f"and world_size ({world_size}). This will result in idle workers. "
f"Consider reducing chunk_size or using fewer workers."
)

# Assign full chunks to all workers except the last
num_items_per_worker = [chunks_per_worker * chunk_size for _ in range(world_size - 1)]

# Last worker receives all remaining items (full chunks + optional tail)
remaining = len(user_items) - sum(num_items_per_worker)
num_items_per_worker.append(remaining)

else:
items_per_worker_count = len(user_items) // world_size

num_items_per_worker: list[int] = [items_per_worker_count for _ in range(world_size)]
reminder = len(user_items) % world_size

for worker_idx in range(len(num_items_per_worker) - 1, -1, -1):
if reminder == 0:
break
num_items_per_worker[worker_idx] += 1
reminder -= 1

num_items_cumsum_per_worker = np.cumsum([0] + num_items_per_worker)

out = []
node_rank = _get_node_rank()
worker_idx_start = node_rank * num_workers
worker_idx_end = (node_rank + 1) * num_workers
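
To make the chunk-aligned split concrete, here is a small worked example of the new code path, a sketch assuming a single node (`DATA_OPTIMIZER_NUM_NODES` unset):

from litdata.processing.data_processor import _map_items_to_workers_sequentially

# Sketch: 10 items, 2 workers on 1 node, chunk_size=3.
# full_chunks = 10 // 3 = 3; chunks_per_worker = 3 // 2 = 1,
# so worker 0 gets one full chunk and the last worker takes the rest.
splits = _map_items_to_workers_sequentially(2, list(range(10)), chunk_size=3)
assert splits == [[0, 1, 2], [3, 4, 5, 6, 7, 8, 9]]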

@@ -1080,6 +1117,7 @@ def __init__(
input_dir: Union[str, Dir],
output_dir: Optional[Union[str, Dir]] = None,
num_workers: Optional[int] = None,
align_chunking: bool = False,
num_downloaders: Optional[int] = None,
num_uploaders: Optional[int] = None,
delete_cached_files: bool = True,
@@ -1102,6 +1140,8 @@ def __init__(
input_dir: The path to where the input data are stored.
output_dir: The path to where the output data are stored.
num_workers: The number of worker threads to use.
align_chunking: Ensures chunk boundaries match the single-worker layout by packing full chunks first
and assigning all remaining items to the final worker.
num_downloaders: The number of file downloaders to use.
num_uploaders: The number of file uploaders to use.
delete_cached_files: Whether to delete the cached files.
@@ -1140,6 +1180,7 @@ def __init__(
self.output_dir = _resolve_dir(output_dir)

self.num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4)
self.align_chunking = align_chunking
self.num_downloaders = num_downloaders or 2
self.num_uploaders = num_uploaders or 1
self.delete_cached_files = delete_cached_files
@@ -1231,8 +1272,14 @@ def run(self, data_recipe: DataRecipe) -> None:
num_workers=self.num_workers, user_items=user_items, weights=item_sizes
)
else:
if self.align_chunking and data_recipe.chunk_size is None:
raise ValueError(
"`align_chunking` is set to True, but the `chunk_size` is not defined in the data recipe."
)
workers_user_items = _map_items_to_workers_sequentially(
num_workers=self.num_workers, user_items=user_items
num_workers=self.num_workers,
user_items=user_items,
chunk_size=data_recipe.chunk_size if self.align_chunking else None,
)
else:
assert isinstance(user_items, multiprocessing.queues.Queue)
9 changes: 9 additions & 0 deletions src/litdata/processing/functions.py
@@ -393,6 +393,7 @@ def optimize(
weights: Optional[list[int]] = None,
chunk_size: Optional[int] = None,
chunk_bytes: Optional[Union[int, str]] = None,
align_chunking: bool = False,
compression: Optional[str] = None,
encryption: Optional[Encryption] = None,
num_workers: Optional[int] = None,
@@ -428,6 +429,10 @@ def optimize(
weights: Provide an associated weight to each input. This is used to balance work among workers.
chunk_size: The maximum number of elements to hold within a chunk.
chunk_bytes: The maximum number of bytes to hold within a chunk.
align_chunking: Ensures chunk boundaries match the single-worker layout by packing full chunks first
and assigning all remaining items to the final worker. Each worker receives only full chunks of
`chunk_size` items, except the last worker, which takes all leftover items and may therefore write a
final partial chunk. Note: this results in uneven workload distribution among workers, and the last
worker may receive more data than the others.
compression: The compression algorithm to use over the chunks.
encryption: The encryption algorithm to use over the chunks.
num_workers: The number of workers to use during processing
@@ -489,6 +494,9 @@ def optimize(
if chunk_size is None and chunk_bytes is None:
raise ValueError("Either `chunk_size` or `chunk_bytes` needs to be defined.")

if align_chunking and chunk_size is None:
raise ValueError("When `align_chunking` is set to True, `chunk_size` needs to be defined.")

if not _IS_IN_STUDIO and (machine is not None or num_nodes is not None):
raise ValueError(
"Only https://lightning.ai/ supports multiple nodes or selecting a machine.Create an account to try it out."
@@ -555,6 +563,7 @@ def optimize(
input_dir=resolved_dir,
output_dir=_output_dir,
num_workers=num_workers or _get_default_num_workers(),
align_chunking=align_chunking,
fast_dev_run=fast_dev_run,
num_downloaders=num_downloaders,
num_uploaders=num_uploaders,
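
For reference, a minimal usage sketch of the new flag; the `encode` function, input range, and output path below are illustrative assumptions, not part of this change:

from litdata import optimize

def encode(index: int) -> dict:
    # Hypothetical per-item transform; any picklable function works here.
    return {"index": index}

if __name__ == "__main__":
    optimize(
        fn=encode,
        inputs=list(range(1024)),
        output_dir="optimized_data",  # illustrative local path
        chunk_size=64,  # required whenever align_chunking=True
        num_workers=4,
        align_chunking=True,  # chunk boundaries match the single-worker layout
    )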
51 changes: 51 additions & 0 deletions tests/processing/test_data_processor.py
@@ -373,6 +373,36 @@ def test_map_items_to_workers_sequentially(monkeypatch):
assert workers_user_items == [[24, 25], [26, 27], [28, 29], [30, 31]]


def test_map_items_to_workers_sequentially_align_chunking(monkeypatch):
workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)), chunk_size=2)
assert workers_user_items == [list(range(5))]
workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)), chunk_size=2)
assert workers_user_items == [[0, 1], [2, 3, 4]]
workers_user_items = _map_items_to_workers_sequentially(2, list(range(6)), chunk_size=2)
assert workers_user_items == [[0, 1], [2, 3, 4, 5]]

monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)), chunk_size=2)
assert workers_user_items == [[0, 1]]

# 2 nodes, 2 workers per node, chunk_size=2.
# Total items = 5, so no worker can form even one full chunk
# (5 / (2 * 2 * 2) = 0.625, which floors to 0 full chunks per worker),
# and the final worker on the last node receives everything.
with pytest.warns(UserWarning, match="Consider reducing chunk_size or using fewer workers"):
workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)), chunk_size=2)
assert workers_user_items == [[], []]

monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)), chunk_size=2)
assert workers_user_items == [[2, 3, 4]]

# On node 1 (rank 1), last worker should receive all items.
workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)), chunk_size=2)
assert workers_user_items == [[], [0, 1, 2, 3, 4]]


def test_fake_queue():
q = FakeQueue()
index = [1, 2]
@@ -400,6 +430,16 @@ def prepare_item(self, item):
return item


class DummyDataChunkRecipe(DataChunkRecipe):
is_generator = False

def prepare_structure(self, input_dir: str) -> list[Any]:
return []

def prepare_item(self, item):
return item


@pytest.mark.parametrize("delete_cached_files", [True])
@pytest.mark.parametrize("fast_dev_run", [10])
@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
@@ -477,6 +517,17 @@ def test_data_processsor(fast_dev_run, delete_cached_files, tmpdir, monkeypatch)
assert len(files) == expected


def test_data_processor_align_chunking_requires_chunk_size(tmpdir):
output_dir = str(tmpdir / "output_dir")
data_processor = DataProcessor(input_dir=Dir(), output_dir=output_dir, num_workers=1, align_chunking=True)
with pytest.raises(ValueError, match="`chunk_size` is not defined in the data recipe"):
data_processor.run(
DummyDataChunkRecipe(
chunk_bytes="10MB" # chunk_size is not defined here to trigger the error
)
)


class TestDataProcessor(DataProcessor):
def _broadcast_object(self, obj: Any) -> Any:
return obj
56 changes: 56 additions & 0 deletions tests/processing/test_functions.py
@@ -1,5 +1,6 @@
import glob
import io
import math
import os
import random
import shutil
@@ -123,6 +124,61 @@ def random_image(index):
return {"image": fake_img, "class": index}


@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
def test_optimize_align_chunking_requires_chunk_size(tmp_path):
output_dir = tmp_path / "output_requires_chunk_size"

with pytest.raises(ValueError, match="`chunk_size` needs to be defined"):
optimize(
fn=compress,
inputs=list(range(7 * 64)),
chunk_bytes="1MB",
output_dir=str(output_dir),
num_workers=1,
align_chunking=True,
)


@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
@pytest.mark.parametrize("num_workers", [1, 2])
@pytest.mark.parametrize("chunk_size", [16, 32, 64])
def test_optimize_align_chunking_creates_expected_chunks(tmp_path, chunk_size, num_workers):
output_dir = tmp_path / f"output_workers_{num_workers}"

inputs = list(range(7 * 64))

optimize(
fn=compress,
inputs=inputs,
chunk_size=chunk_size,
output_dir=str(output_dir),
num_workers=num_workers,
align_chunking=True,
)

assert output_dir.exists()

actual_files = set(os.listdir(output_dir))

total_items = len(inputs)
items_per_worker = total_items / num_workers
chunks_per_worker = items_per_worker / chunk_size

# each worker should create `math.floor(chunks_per_worker)` chunks, except the last worker,
# which also absorbs the remaining items and therefore creates `math.ceil(chunks_per_worker)` chunks
expected_chunks_by_worker = {
worker_id: (math.floor(chunks_per_worker) if worker_id < num_workers - 1 else math.ceil(chunks_per_worker))
for worker_id in range(num_workers)
}

expected_chunk_files = {
f"chunk-{worker_id}-{i}.bin" for worker_id, indices in expected_chunks_by_worker.items() for i in range(indices)
}
expected_files = expected_chunk_files | {"index.json"}

assert actual_files == expected_files


@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
def test_optimize_append_overwrite(tmpdir):
output_dir = str(tmpdir / "output_dir")
33 changes: 33 additions & 0 deletions tests/streaming/test_dataloader.py
@@ -6,6 +6,7 @@
from torch import tensor

from litdata.constants import _VIZ_TRACKER_AVAILABLE
from litdata.processing.functions import optimize
from litdata.streaming import (
Cache,
CombinedStreamingDataset,
@@ -496,3 +497,35 @@ def test_dataloader_dataset_transform_inheritance(tmpdir, shuffle):
# Verify that the transform is applied correctly
for i, item in enumerate(complete_data):
assert item == i * 2, f"Expected {i * 2}, got {item}"


def getter(index: int):
return index


@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
@pytest.mark.parametrize("num_workers", [1, 2])
def test_dataloader_with_align_chunking(tmp_path, num_workers):
output_dir = tmp_path / f"output_workers_{num_workers}"

optimize(
fn=getter,
inputs=list(range(7 * 64)),
chunk_size=64,
output_dir=str(output_dir),
num_workers=num_workers,
align_chunking=True,
)

# Ensure batches contain elements from the same chunk when using align_chunking
dataset = StreamingDataset(str(output_dir), shuffle=True)

# the dataloader's batch_size must equal the chunk_size used during optimize so each batch maps onto one chunk
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=num_workers, shuffle=True)

for i, batch in enumerate(dataloader):
min_element_in_batch = torch.min(batch).item()
max_element_in_batch = torch.max(batch).item()
assert max_element_in_batch - min_element_in_batch < 64, (
f"Batch {i} contains elements from multiple chunks: min {min_element_in_batch}, max {max_element_in_batch}"
)