Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 51 additions & 12 deletions constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,58 @@

SALT_LANGUAGE_TOKENS_WHISPER = {
    # Maps ISO 639-3 language codes used by SALT to Whisper special-token IDs.
    #
    # Exact/close mapping: languages Whisper already supports, mapped to
    # Whisper's own language token for that language.
    "eng": 50259,
    "fra": 50265,
    "swa": 50318,
    "sna": 50324,
    "yor": 50325,
    "som": 50326,
    "afr": 50327,
    "amh": 50334,
    "mlg": 50349,
    "lin": 50353,
    "hau": 50354,
    # Overwrite unused language tokens: languages Whisper does not support are
    # assigned token IDs counting down from 50357, skipping any ID already
    # taken by the exact/close mappings above (e.g. 50354-50353, 50349, 50334,
    # 50327-50324, 50318).
    "ach": 50357,
    "aka": 50356,
    "bam": 50355,
    "bem": 50352,
    "ber": 50351,
    "cgg": 50350,
    "dag": 50348,
    "dga": 50347,
    "ewe": 50346,
    "ful": 50345,
    "ibo": 50344,
    "kab": 50343,
    "kau": 50342,
    "kik": 50341,
    "kin": 50340,
    "kln": 50339,
    "koo": 50338,
    "kpo": 50337,
    "led": 50336,
    "lgg": 50335,
    "lth": 50333,
    "lug": 50332,
    "luo": 50331,
    "luy": 50330,
    "myx": 50329,
    "nbl": 50328,
    "nya": 50323,
    "nyn": 50322,
    "orm": 50321,
    "pcm": 50320,
    "ruc": 50319,
    "rwm": 50317,
    "sot": 50316,
    "teo": 50315,
    "tsn": 50314,
    "ttj": 50313,
    "wol": 50312,
    "xho": 50311,
    "xog": 50310,
    "zul": 50309,
}

SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION = {
Expand Down
164 changes: 125 additions & 39 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,27 @@ def _add_speaker_id_studio_if_not_present(sample):
return sample


def _load_single_huggingface_dataset(load_dataset_params):
ds = datasets.load_dataset(**load_dataset_params)
def _load_single_dataset(load_dataset_params, gcs_key_path=None, max_examples=None):
# if path contains gcs://, then load with google_default token
if "path" in load_dataset_params and "gcs://" in load_dataset_params["path"]:
if gcs_key_path is None:
raise ValueError(
"gcs_key_path must be provided when loading from Google Cloud Storage."
)
print("downloading datasets from: ", load_dataset_params["path"])
ds = datasets.load_dataset(
"parquet",
data_files=load_dataset_params["path"],
storage_options={"token": gcs_key_path},
)
ds= ds.cast_column("audio", datasets.Audio())
print("download completed: ", load_dataset_params["path"])
# otherwise load from hugging face with the provided params
else:
ds = datasets.load_dataset(**load_dataset_params)

if isinstance(ds, datasets.DatasetDict):
split_names = list(ds.data.keys())
split_names = list(ds.keys())
# If the split wasn't specified, but there's only one, then just go
# ahead and load that one.
if len(split_names) == 1:
Expand All @@ -88,6 +105,9 @@ def _load_single_huggingface_dataset(load_dataset_params):
f"{load_dataset_params}. Splits found: {list(ds.keys())}."
)

if max_examples is not None:
ds = ds.select(range(min(len(ds), max_examples)))

remap_names = {
"audio_language": "language",
}
Expand Down Expand Up @@ -206,30 +226,42 @@ def _dataset_id_from_config(load_params):
return "_".join(tag)


def _load_huggingface_datasets(config):
"""Retrieve all specified HuggingFace datasets and return as a list."""
def _load_datasets(config):
"""Retrieve all specified datasets and return as a list."""
loaded_datasets = []
if "huggingface_load" not in config:
if "datasets" not in config and "huggingface_load" not in config:
raise ValueError(
"There should be a `huggingface_load` entry in the dataset config, "
"There should be a `datasets` or `huggingface_load` entry in the dataset config, "
f"specifying which datasets to download. Got: {config}."
)

load_list = config["huggingface_load"]
load_list = config["datasets"] if "datasets" in config else config["huggingface_load"]
gcs_key_path = config.get("gcs_key_path")
max_examples = config.get("max_examples_per_dataset")

# Optionally pre-download everything at once
if config.get("download_datasets_in_parallel"):
# Disable progress bars to prevent tqdm thread-safety issues (e.g. `_lock` AttributeError)
datasets.disable_progress_bar()
# Explicitly initialize the tqdm lock in the main thread to prevent concurrent `del tqdm_class._lock` bugs
import tqdm
tqdm.tqdm.get_lock()

threads = []
for l in _ensure_list(load_list):
if "join" in l:
for i in (0, 1):
thread = threading.Thread(
target=_load_single_huggingface_dataset, args=(l["join"][i],)
target=_load_single_dataset,
args=(l["join"][i],),
kwargs={"gcs_key_path": gcs_key_path},
)
threads.append(thread)
else:
thread = threading.Thread(
target=_load_single_huggingface_dataset, args=(l,)
target=_load_single_dataset,
args=(l,),
kwargs={"gcs_key_path": gcs_key_path},
)
threads.append(thread)
thread.start()
Expand All @@ -243,8 +275,12 @@ def _load_huggingface_datasets(config):
"If a dataset join is specified, then there should be a "
f"list of exactly two datasets to be joined. Got: {l}."
)
left = _load_single_huggingface_dataset(l["join"][0])
right = _load_single_huggingface_dataset(l["join"][1])
left = _load_single_dataset(
l["join"][0], gcs_key_path=gcs_key_path, max_examples=max_examples
)
right = _load_single_dataset(
l["join"][1], gcs_key_path=gcs_key_path, max_examples=max_examples
)

generator_function = lambda: _combine_datasets_generator(left, right)
ds = datasets.IterableDataset.from_generator(generator_function)
Expand All @@ -254,7 +290,9 @@ def _load_huggingface_datasets(config):
+ _dataset_id_from_config(l["join"][1])
)
else:
ds = _load_single_huggingface_dataset(l)
ds = _load_single_dataset(
l, gcs_key_path=gcs_key_path, max_examples=max_examples
)
dataset_id = _dataset_id_from_config(l)
loaded_datasets.append([ds, dataset_id])

Expand Down Expand Up @@ -443,13 +481,12 @@ def _matching_pairs(row, config):
yield example


def _create_generator(config, verbose=False):
def _create_generator(loaded_datasets, config, verbose=False):
"""Make a generator that yields examples according to dataset spec."""
huggingface_datasets = _load_huggingface_datasets(config)

if verbose:
total_row_count = 0
for ds, id in huggingface_datasets:
for ds, id in loaded_datasets:
row_count = len(ds)
total_row_count += row_count
print(f"{id}: {row_count} rows")
Expand All @@ -459,29 +496,40 @@ def _yield_matches(batch, config, dataset_id):
keys = list(batch.keys())
rows = [{k: batch[k][i] for k in keys} for i in range(len(batch[keys[0]]))]
for row in rows:
# The audio SALT datasets are in a slightly different format
# to the translation data, each row having a 'text' and 'language'
# field.
if "audio" in row and "text" in row:
row[row["language"] + "_text"] = row["text"]
del row["text"]
for match in _matching_pairs(row | {"origin_dataset": dataset_id}, config):
yield match
if config.get("skip_matching_asr"):
audio_array, sample_rate = _get_audio_from_row(row)
example = {
"source": audio_array,
"source.sample_rate": sample_rate,
"source.language": row.get("language"),
"target": row.get("text"),
"target.language": row.get("language"),
}
yield example
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good. A further improvement for later, in case it's an ASR/audio dataset and the format already matches, is not to use a generator at all - we just load the huggingface datasets and concatenate them. That should reduce CPU bottleneck and could improve GPU utilisation.

else:
# The audio SALT datasets are in a slightly different format
# to the translation data, each row having a 'text' and 'language'
# field.
if "audio" in row and "text" in row:
row[row["language"] + "_text"] = row["text"]
del row["text"]
for match in _matching_pairs(row | {"origin_dataset": dataset_id}, config):
yield match

# PyArrow data should be read in batches for speed.
PYARROW_BATCH_SIZE = 10
num_workers = config.get("num_workers", 4)

if config.get("shuffle") and len(huggingface_datasets) > 1:
if config.get("shuffle") and len(loaded_datasets) > 1:
# If there are multiple datasets concatenated and 'shuffle' is
# specified, then we want to randomly interleave them.
iterators = [
d[0].iter(batch_size=PYARROW_BATCH_SIZE) for d in huggingface_datasets
d[0].iter(batch_size=PYARROW_BATCH_SIZE) for d in loaded_datasets
]
iterator_order = []
for i in range(len(huggingface_datasets)):
for i in range(len(loaded_datasets)):
num_batches = math.ceil(
len(huggingface_datasets[i][0]) / PYARROW_BATCH_SIZE
len(loaded_datasets[i][0]) / PYARROW_BATCH_SIZE
)
iterator_order.extend([i] * num_batches)
permutation = np.random.permutation(len(iterator_order))
Expand All @@ -502,13 +550,13 @@ def process_batch(args):
batch = next(iterators[iterator_id])
future = executor.submit(
process_batch,
(batch, config, huggingface_datasets[iterator_id][1]),
(batch, config, loaded_datasets[iterator_id][1]),
)
future_queue.put((future, iterator_id))
except StopIteration:
break
except Exception as e:
print("Error reading from " + huggingface_datasets[iterator_id][1])
print("Error reading from " + loaded_datasets[iterator_id][1])
raise
while not future_queue.empty():
future, iterator_id = future_queue.get()
Expand All @@ -520,13 +568,13 @@ def process_batch(args):
batch = next(iterators[iterator_id])
next_future = executor.submit(
process_batch,
(batch, config, huggingface_datasets[iterator_id][1]),
(batch, config, loaded_datasets[iterator_id][1]),
)
future_queue.put((next_future, iterator_id))
except StopIteration:
continue
except Exception as e:
print("Error reading from " + huggingface_datasets[iterator_id][1])
print("Error reading from " + loaded_datasets[iterator_id][1])
raise
elif config.get("shuffle"):
# Single dataset, shuffle is True: parallelize batches, order doesn't matter
Expand All @@ -536,7 +584,7 @@ def process_batch(args):

with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for ds, dataset_id in huggingface_datasets:
for ds, dataset_id in loaded_datasets:
for batch in ds.iter(batch_size=PYARROW_BATCH_SIZE):
futures.append(
executor.submit(process_batch, (batch, config, dataset_id))
Expand All @@ -545,10 +593,43 @@ def process_batch(args):
for match in future.result():
yield match
else:
# No shuffle: preserve strict order, process sequentially
for ds, dataset_id in huggingface_datasets:
for batch in ds.iter(batch_size=PYARROW_BATCH_SIZE):
yield from _yield_matches(batch, config, dataset_id)
# No shuffle: preserve strict order, parallelize processing
def process_batch(args):
batch, config, dataset_id = args
return list(_yield_matches(batch, config, dataset_id))

prefetch_limit = 4 * num_workers
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
future_queue = queue.Queue()

def batch_generator():
for ds, dataset_id in loaded_datasets:
for batch in ds.iter(batch_size=PYARROW_BATCH_SIZE):
yield batch, config, dataset_id

batch_iter = batch_generator()

# Prefill the queue
for _ in range(prefetch_limit):
try:
args = next(batch_iter)
future = executor.submit(process_batch, args)
future_queue.put(future)
except StopIteration:
break

while not future_queue.empty():
future = future_queue.get()
for match in future.result():
yield match

# Submit the next batch if available
try:
args = next(batch_iter)
next_future = executor.submit(process_batch, args)
future_queue.put(next_future)
except StopIteration:
continue


def _compose(functions):
Expand Down Expand Up @@ -627,7 +708,7 @@ def create(config, verbose=False):
Create a dataset from the given configuration.

Args:
huggingface_load : Dict containing keyword arguments to HuggingFace
datasets : Dict containing keyword arguments to
datasets.load_dataset(), or a list of dicts to load multiple
datasets. The dataset should be in SALT format, as per
hf.co/datasets/sunbird/salt. Common Voice is also supported, if
Expand Down Expand Up @@ -662,8 +743,12 @@ def create(config, verbose=False):
f"{language}. Change to [{language}] to make it a list."
)

generator_function = lambda: _create_generator(config, verbose=verbose)
loaded_datasets = _load_datasets(config)
total_row_count = sum(len(ds) for ds, id in loaded_datasets)

generator_function = lambda: _create_generator(loaded_datasets, config, verbose=verbose)
ds = datasets.IterableDataset.from_generator(generator_function)
ds.approx_row_count = total_row_count

# The individual datasets are already shuffled as needed, but do a little
# more so that consecutive samples are from different batches.
Expand All @@ -678,4 +763,5 @@ def create(config, verbose=False):
ds = ds.select_columns(
["source", "target", "source.language", "target.language"]
)
ds.approx_row_count = total_row_count
return ds
Loading