diff --git a/constants.py b/constants.py index c7c271d..75da33d 100644 --- a/constants.py +++ b/constants.py @@ -80,19 +80,58 @@ SALT_LANGUAGE_TOKENS_WHISPER = { # Exact/close mapping - 'eng': 50259, - 'swa': 50318, + "eng": 50259, + "fra": 50265, + "swa": 50318, + "sna": 50324, + "yor": 50325, + "som": 50326, + "afr": 50327, + "amh": 50334, + "mlg": 50349, + "lin": 50353, + "hau": 50354, # Overwrite unused language tokens - 'ach': 50357, - 'lgg': 50356, - 'lug': 50355, - 'nyn': 50354, - 'teo': 50353, - 'xog': 50352, - 'ttj': 50351, - 'kin': 50350, - 'myx': 50349, - 'kik': 50348, + "ach": 50357, + "aka": 50356, + "bam": 50355, + "bem": 50352, + "ber": 50351, + "cgg": 50350, + "dag": 50348, + "dga": 50347, + "ewe": 50346, + "ful": 50345, + "ibo": 50344, + "kab": 50343, + "kau": 50342, + "kik": 50341, + "kin": 50340, + "kln": 50339, + "koo": 50338, + "kpo": 50337, + "led": 50336, + "lgg": 50335, + "lth": 50333, + "lug": 50332, + "luo": 50331, + "luy": 50330, + "myx": 50329, + "nbl": 50328, + "nya": 50323, + "nyn": 50322, + "orm": 50321, + "pcm": 50320, + "ruc": 50319, + "rwm": 50317, + "sot": 50316, + "teo": 50315, + "tsn": 50314, + "ttj": 50313, + "wol": 50312, + "xho": 50311, + "xog": 50310, + "zul": 50309, } SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION = { diff --git a/dataset.py b/dataset.py index a80d975..e028f05 100644 --- a/dataset.py +++ b/dataset.py @@ -73,10 +73,27 @@ def _add_speaker_id_studio_if_not_present(sample): return sample -def _load_single_huggingface_dataset(load_dataset_params): - ds = datasets.load_dataset(**load_dataset_params) +def _load_single_dataset(load_dataset_params, gcs_key_path=None, max_examples=None): + # if path contains gcs://, then load with google_default token + if "path" in load_dataset_params and "gcs://" in load_dataset_params["path"]: + if gcs_key_path is None: + raise ValueError( + "gcs_key_path must be provided when loading from Google Cloud Storage." + ) + print("downloading datasets from: ", load_dataset_params["path"]) + ds = datasets.load_dataset( + "parquet", + data_files=load_dataset_params["path"], + storage_options={"token": gcs_key_path}, + ) + ds= ds.cast_column("audio", datasets.Audio()) + print("download completed: ", load_dataset_params["path"]) + # otherwise load from hugging face with the provided params + else: + ds = datasets.load_dataset(**load_dataset_params) + if isinstance(ds, datasets.DatasetDict): - split_names = list(ds.data.keys()) + split_names = list(ds.keys()) # If the split wasn't specified, but there's only one, then just go # ahead and load that one. if len(split_names) == 1: @@ -88,6 +105,9 @@ def _load_single_huggingface_dataset(load_dataset_params): f"{load_dataset_params}. Splits found: {list(ds.keys())}." ) + if max_examples is not None: + ds = ds.select(range(min(len(ds), max_examples))) + remap_names = { "audio_language": "language", } @@ -206,30 +226,42 @@ def _dataset_id_from_config(load_params): return "_".join(tag) -def _load_huggingface_datasets(config): - """Retrieve all specified HuggingFace datasets and return as a list.""" +def _load_datasets(config): + """Retrieve all specified datasets and return as a list.""" loaded_datasets = [] - if "huggingface_load" not in config: + if "datasets" not in config and "huggingface_load" not in config: raise ValueError( - "There should be a `huggingface_load` entry in the dataset config, " + "There should be a `datasets` or `huggingface_load` entry in the dataset config, " f"specifying which datasets to download. Got: {config}." ) - load_list = config["huggingface_load"] + load_list = config["datasets"] if "datasets" in config else config["huggingface_load"] + gcs_key_path = config.get("gcs_key_path") + max_examples = config.get("max_examples_per_dataset") # Optionally pre-download everything at once if config.get("download_datasets_in_parallel"): + # Disable progress bars to prevent tqdm thread-safety issues (e.g. `_lock` AttributeError) + datasets.disable_progress_bar() + # Explicitly initialize the tqdm lock in the main thread to prevent concurrent `del tqdm_class._lock` bugs + import tqdm + tqdm.tqdm.get_lock() + threads = [] for l in _ensure_list(load_list): if "join" in l: for i in (0, 1): thread = threading.Thread( - target=_load_single_huggingface_dataset, args=(l["join"][i],) + target=_load_single_dataset, + args=(l["join"][i],), + kwargs={"gcs_key_path": gcs_key_path}, ) threads.append(thread) else: thread = threading.Thread( - target=_load_single_huggingface_dataset, args=(l,) + target=_load_single_dataset, + args=(l,), + kwargs={"gcs_key_path": gcs_key_path}, ) threads.append(thread) thread.start() @@ -243,8 +275,12 @@ def _load_huggingface_datasets(config): "If a dataset join is specified, then there should be a " f"list of exactly two datasets to be joined. Got: {l}." ) - left = _load_single_huggingface_dataset(l["join"][0]) - right = _load_single_huggingface_dataset(l["join"][1]) + left = _load_single_dataset( + l["join"][0], gcs_key_path=gcs_key_path, max_examples=max_examples + ) + right = _load_single_dataset( + l["join"][1], gcs_key_path=gcs_key_path, max_examples=max_examples + ) generator_function = lambda: _combine_datasets_generator(left, right) ds = datasets.IterableDataset.from_generator(generator_function) @@ -254,7 +290,9 @@ def _load_huggingface_datasets(config): + _dataset_id_from_config(l["join"][1]) ) else: - ds = _load_single_huggingface_dataset(l) + ds = _load_single_dataset( + l, gcs_key_path=gcs_key_path, max_examples=max_examples + ) dataset_id = _dataset_id_from_config(l) loaded_datasets.append([ds, dataset_id]) @@ -443,13 +481,12 @@ def _matching_pairs(row, config): yield example -def _create_generator(config, verbose=False): +def _create_generator(loaded_datasets, config, verbose=False): """Make a generator that yields examples according to dataset spec.""" - huggingface_datasets = _load_huggingface_datasets(config) if verbose: total_row_count = 0 - for ds, id in huggingface_datasets: + for ds, id in loaded_datasets: row_count = len(ds) total_row_count += row_count print(f"{id}: {row_count} rows") @@ -459,29 +496,40 @@ def _yield_matches(batch, config, dataset_id): keys = list(batch.keys()) rows = [{k: batch[k][i] for k in keys} for i in range(len(batch[keys[0]]))] for row in rows: - # The audio SALT datasets are in a slightly different format - # to the translation data, each row having a 'text' and 'language' - # field. - if "audio" in row and "text" in row: - row[row["language"] + "_text"] = row["text"] - del row["text"] - for match in _matching_pairs(row | {"origin_dataset": dataset_id}, config): - yield match + if config.get("skip_matching_asr"): + audio_array, sample_rate = _get_audio_from_row(row) + example = { + "source": audio_array, + "source.sample_rate": sample_rate, + "source.language": row.get("language"), + "target": row.get("text"), + "target.language": row.get("language"), + } + yield example + else: + # The audio SALT datasets are in a slightly different format + # to the translation data, each row having a 'text' and 'language' + # field. + if "audio" in row and "text" in row: + row[row["language"] + "_text"] = row["text"] + del row["text"] + for match in _matching_pairs(row | {"origin_dataset": dataset_id}, config): + yield match # PyArrow data should be read in batches for speed. PYARROW_BATCH_SIZE = 10 num_workers = config.get("num_workers", 4) - if config.get("shuffle") and len(huggingface_datasets) > 1: + if config.get("shuffle") and len(loaded_datasets) > 1: # If there are multiple datasets concatenated and 'shuffle' is # specified, then we want to randomly interleave them. iterators = [ - d[0].iter(batch_size=PYARROW_BATCH_SIZE) for d in huggingface_datasets + d[0].iter(batch_size=PYARROW_BATCH_SIZE) for d in loaded_datasets ] iterator_order = [] - for i in range(len(huggingface_datasets)): + for i in range(len(loaded_datasets)): num_batches = math.ceil( - len(huggingface_datasets[i][0]) / PYARROW_BATCH_SIZE + len(loaded_datasets[i][0]) / PYARROW_BATCH_SIZE ) iterator_order.extend([i] * num_batches) permutation = np.random.permutation(len(iterator_order)) @@ -502,13 +550,13 @@ def process_batch(args): batch = next(iterators[iterator_id]) future = executor.submit( process_batch, - (batch, config, huggingface_datasets[iterator_id][1]), + (batch, config, loaded_datasets[iterator_id][1]), ) future_queue.put((future, iterator_id)) except StopIteration: break except Exception as e: - print("Error reading from " + huggingface_datasets[iterator_id][1]) + print("Error reading from " + loaded_datasets[iterator_id][1]) raise while not future_queue.empty(): future, iterator_id = future_queue.get() @@ -520,13 +568,13 @@ def process_batch(args): batch = next(iterators[iterator_id]) next_future = executor.submit( process_batch, - (batch, config, huggingface_datasets[iterator_id][1]), + (batch, config, loaded_datasets[iterator_id][1]), ) future_queue.put((next_future, iterator_id)) except StopIteration: continue except Exception as e: - print("Error reading from " + huggingface_datasets[iterator_id][1]) + print("Error reading from " + loaded_datasets[iterator_id][1]) raise elif config.get("shuffle"): # Single dataset, shuffle is True: parallelize batches, order doesn't matter @@ -536,7 +584,7 @@ def process_batch(args): with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: futures = [] - for ds, dataset_id in huggingface_datasets: + for ds, dataset_id in loaded_datasets: for batch in ds.iter(batch_size=PYARROW_BATCH_SIZE): futures.append( executor.submit(process_batch, (batch, config, dataset_id)) @@ -545,10 +593,43 @@ def process_batch(args): for match in future.result(): yield match else: - # No shuffle: preserve strict order, process sequentially - for ds, dataset_id in huggingface_datasets: - for batch in ds.iter(batch_size=PYARROW_BATCH_SIZE): - yield from _yield_matches(batch, config, dataset_id) + # No shuffle: preserve strict order, parallelize processing + def process_batch(args): + batch, config, dataset_id = args + return list(_yield_matches(batch, config, dataset_id)) + + prefetch_limit = 4 * num_workers + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + future_queue = queue.Queue() + + def batch_generator(): + for ds, dataset_id in loaded_datasets: + for batch in ds.iter(batch_size=PYARROW_BATCH_SIZE): + yield batch, config, dataset_id + + batch_iter = batch_generator() + + # Prefill the queue + for _ in range(prefetch_limit): + try: + args = next(batch_iter) + future = executor.submit(process_batch, args) + future_queue.put(future) + except StopIteration: + break + + while not future_queue.empty(): + future = future_queue.get() + for match in future.result(): + yield match + + # Submit the next batch if available + try: + args = next(batch_iter) + next_future = executor.submit(process_batch, args) + future_queue.put(next_future) + except StopIteration: + continue def _compose(functions): @@ -627,7 +708,7 @@ def create(config, verbose=False): Create a dataset from the given configuration. Args: - huggingface_load : Dict containing keyword arguments to HuggingFace + datasets : Dict containing keyword arguments to datasets.load_dataset(), or a list of dicts to load multiple datasets. The dataset should be in SALT format, as per hf.co/datasets/sunbird/salt. Common Voice is also supported, if @@ -662,8 +743,12 @@ def create(config, verbose=False): f"{language}. Change to [{language}] to make it a list." ) - generator_function = lambda: _create_generator(config, verbose=verbose) + loaded_datasets = _load_datasets(config) + total_row_count = sum(len(ds) for ds, id in loaded_datasets) + + generator_function = lambda: _create_generator(loaded_datasets, config, verbose=verbose) ds = datasets.IterableDataset.from_generator(generator_function) + ds.approx_row_count = total_row_count # The individual datasets are already shuffled as needed, but do a little # more so that consecutive samples are from different batches. @@ -678,4 +763,5 @@ def create(config, verbose=False): ds = ds.select_columns( ["source", "target", "source.language", "target.language"] ) + ds.approx_row_count = total_row_count return ds diff --git a/metrics.py b/metrics.py index 0489a1f..f3d74b7 100644 --- a/metrics.py +++ b/metrics.py @@ -4,6 +4,7 @@ import pandas as pd import functools import numbers +import tqdm def _normalise(string_list, @@ -82,7 +83,7 @@ def _round_if_float(f, p): f'True label: "{decoded_labels[i]}"') subsets = {} - for i in range(len(decoded_predictions)): + for i in tqdm.tqdm(range(len(decoded_predictions)), desc="Grouping language subsets"): if speech_processor: # For speech metrics, such as WER, we evaluate for separate target # languages. @@ -146,9 +147,25 @@ def multilingual_eval_fn(eval_dataset, If `speech_processor` is defined, then it is used to decode the predictions. Otherwise `tokenizer` is used (e.g. for translation tasks).''' - df = pd.DataFrame(eval_dataset) - source_language = list(df['source.language']) - target_language = list(df['target.language']) + try: + from . import dataset as salt_dataset + except ImportError: + import dataset as salt_dataset + + original_get_audio = salt_dataset._get_audio_from_row + + try: + # Mock audio extraction to bypass heavy CPU decoding + salt_dataset._get_audio_from_row = lambda row: (np.zeros(1), 16000) + + source_language = [] + target_language = [] + for item in tqdm.tqdm(eval_dataset, desc="Extracting source and target languages"): + source_language.append(item['source.language']) + target_language.append(item['target.language']) + finally: + # Restore the original method so Trainer evaluation loop works correctly + salt_dataset._get_audio_from_row = original_get_audio metric_names = [] for m in metrics: diff --git a/notebooks/training/whisper_finetuning_gcs.ipynb b/notebooks/training/whisper_finetuning_gcs.ipynb new file mode 100644 index 0000000..3acb613 --- /dev/null +++ b/notebooks/training/whisper_finetuning_gcs.ipynb @@ -0,0 +1,1035 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Working on RTX 6000Ada 48GB (per-device batch size 2) and H100 80GB (per-device batch size 16)\n", + "\n", + "Before running cells below, please open a terminal and run the following command under the repo's root directory to set up the environment:\n", + "```\n", + "bash whisper_training_setup.sh\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4wWOTwrVCqnF" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (7.2.1)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: pynvml in /usr/local/lib/python3.12/dist-packages (13.0.1)\n", + "Requirement already satisfied: nvidia-ml-py>=12.0.0 in /usr/local/lib/python3.12/dist-packages (from pynvml) (13.590.48)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Enter the MLFLOW_TRACKING_USERNAME: ········\n", + "Enter the MLFLOW_TRACKING_PASSWORD: ········\n" + ] + } + ], + "source": [ + "use_wandb = False\n", + "use_mlflow = True\n", + "\n", + "import importlib.metadata\n", + "installed = [\n", + " dist.metadata['Name']\n", + " for dist in importlib.metadata.distributions()\n", + "]\n", + "\n", + "if use_wandb:\n", + " !pip install wandb\n", + " import wandb\n", + " %set_env WANDB_LOG_MODEL=True\n", + " %set_env WANDB_WATCH=all\n", + " %set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb\n", + " wandb.login()\n", + "\n", + "if use_mlflow:\n", + " if 'mlflow' not in installed:\n", + " !pip install mlflow\n", + " ## requirements to log system/GPU metrics in mlflow\n", + " !pip install psutil\n", + " !pip install pynvml\n", + " import os\n", + " from getpass import getpass\n", + " import mlflow\n", + " import mlflow.pytorch\n", + " from mlflow import MlflowClient\n", + "\n", + " mlflow.enable_system_metrics_logging()\n", + "\n", + " # Set MLflow tracking credentials\n", + " MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')\n", + " os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME\n", + "\n", + " MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')\n", + " os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD\n", + "\n", + " # Set the MLflow tracking URI\n", + " mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')\n", + " mlflow.system_metrics.enable_system_metrics_logging()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "6KKzjwqnb6Dh" + }, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "root_dir = Path.cwd().parent.parent.parent\n", + "sys.path.append(str(root_dir))\n", + "\n", + "import torch\n", + "import transformers\n", + "from dataclasses import dataclass, field\n", + "from typing import Union, List, Dict, Any\n", + "import string\n", + "import os\n", + "import json\n", + "import datasets\n", + "import numpy as np\n", + "import yaml\n", + "import evaluate\n", + "import salt.dataset\n", + "import salt.metrics\n", + "import salt.constants\n", + "from salt.utils import DataCollatorCTCWithPadding as dcwp\n", + "import huggingface_hub\n", + "import peft\n", + "import pandas as pd\n", + "import gcsfs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IlbdSLfKNYfF", + "scrolled": true + }, + "outputs": [], + "source": [ + "huggingface_hub.notebook_login()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "List all the available datasets in google cloud buckets for training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "gcs = gcsfs.GCSFileSystem(project='sb-gcp-project-01', token='gcs-data-viewer-key.json')\n", + "bucket_path = \"sunflower-data/speech\"\n", + "base_gcs = f\"gcs://{bucket_path}\"\n", + "all_datasets = gcs.ls(path=bucket_path, detail=False)\n", + "all_datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create Datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the yaml file that contains specific datasets and processing methods for training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open('configs/whisper_finetuning_gcs.yaml', 'r') as file:\n", + " config = yaml.safe_load(file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # show the datasets not included for training (sometimes just empty folders)\n", + "# train_datasets = set([ds['path'].replace('/train/*.parquet', '').replace('gcs://', '') \n", + "# for ds in config['train']['datasets']])\n", + "# set(all_datasets) - train_datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "i1Vd4A4UIwLk" + }, + "outputs": [], + "source": [ + "train_ds = salt.dataset.create(config['train'], verbose=True)\n", + "valid_ds = salt.dataset.create(config['validation'], verbose=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspect the processed datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Example 0 ---\n", + "source: [ 2.0106706e-06 3.2367443e-06 -2.7039619e-06 ... -1.4706944e-06\n", + " 2.6853128e-05 1.4958448e-05]\n", + "target: Ikiraro cya gari ya moshi kimaze kubakwa hejuru yuruzi. Ejo bundi nasanze igitabo cyanditswe na data. Ugomba guhindura uruhushya rwawe rwa kera kugirango ubone urundi rushya. Yemeza mushiki wa Denis, Ndayishimiye yakomeje avuga ko hari ingaruka nyinshi. Maze bigafasha mwarimu uburyo yatangamo isomo rye neza.\n", + "Languages: lug -> lug\n", + "\n" + ] + } + ], + "source": [ + "for i, example in enumerate(train_ds.take(1)):\n", + " print(f\"--- Example {i} ---\")\n", + " \n", + " # Check if 'source' or 'target' contains large audio arrays\n", + " # If they are dicts with 'array' keys, print just the shape/metadata\n", + " for key in ['source', 'target']:\n", + " data = example.get(key)\n", + " if isinstance(data, dict) and 'array' in data:\n", + " # Avoid printing the massive raw audio array\n", + " print(f\"{key}: [Audio Data] Sample Rate: {data.get('sampling_rate')}, Shape: {data['array'].shape}\")\n", + " else:\n", + " print(f\"{key}: {data}\")\n", + " \n", + " print(f\"Languages: {example.get('source.language')} -> {example.get('target.language')}\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yeE3B_MdJxAu" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sourcetargetsource.languagetarget.language
0Hazaba ikindi cyerekanwa cyiyi firime mumasaha abiri. Shimirwa nyagasani, imbaraga zawe. David ashobora gufungwa burundu. Akajya ayatekereza. Maze abungeri basubirayo bahimbaza, bashima Imana.luglug
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "salt.utils.show_dataset(train_ds, audio_features=['source'], N=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Whisper Model and Data Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9BNxLEzRpNey" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f876e09209a64c8b9b6a02702f4bf0df", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading weights: 0%| | 0/1259 [00:00 Dict[str, torch.Tensor]:\n", + " # split inputs and labels since they have to be of different lengths and need different padding methods\n", + " # first treat the audio inputs by simply returning torch tensors\n", + " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features] \n", + " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n", + "\n", + " # get the tokenized label sequences\n", + " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", + " # pad the labels to max length\n", + " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n", + "\n", + " # replace padding with -100 to ignore loss correctly\n", + " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", + "\n", + " # if bos token is appended in previous tokenization step,\n", + " # cut bos token here as it's append later anyways\n", + " if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():\n", + " labels = labels[:, 1:]\n", + "\n", + " batch[\"labels\"] = labels\n", + "\n", + " return batch\n", + "\n", + "data_collator = DataCollatorSpeechSeq2SeqWithPadding(\n", + " processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Read in prompts: preceding text which is used to guide the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# sentences = datasets.load_dataset(\n", + "# 'Sunbird/salt', 'text-all', split='train').to_pandas()\n", + "# prompts = datasets.load_dataset(\n", + "# 'Sunbird/prompts', split='train').to_pandas()\n", + "# joined = pd.merge(sentences, prompts, on='id', how='inner')\n", + "# SALT_PROMPT_LANGUAGES = ['eng', 'ach', 'lgg', 'lug', 'nyn', 'teo']\n", + "# sentence_to_prompt = {}\n", + "# for language in SALT_PROMPT_LANGUAGES:\n", + "# sentence_to_prompt[language] = dict(\n", + "# zip(joined[f'{language}_text'], joined[f'{language}_prompt']))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4mzVFDogXgLG" + }, + "outputs": [], + "source": [ + "language_id_tokens = salt.constants.SALT_LANGUAGE_TOKENS_WHISPER\n", + "\n", + "def prepare_dataset(example, p_prompt = 0.5): \n", + " audio = example[\"source\"]\n", + " input_features = feature_extractor(\n", + " audio, sampling_rate=16000, device='cuda',\n", + " do_normalize=True).input_features[0]\n", + "\n", + " # Encode target text to label ids\n", + " labels = processor.tokenizer(str(example[\"target\"])).input_ids\n", + "\n", + " # Insert the language ID token into the second position of the sequence.\n", + " labels.insert(1, language_id_tokens[example[\"target.language\"]])\n", + "\n", + " # # If a prompt is known for a particular sentence, add it to the\n", + " # # training example with probability `p_prompt`.\n", + " # if example[\"target.language\"] in sentence_to_prompt:\n", + " # prompt = sentence_to_prompt[example[\"target.language\"]].get(example[\"target\"], None)\n", + " # if prompt:\n", + " # if np.random.random() < p_prompt:\n", + " # prompt_ids = list(processor.get_prompt_ids(prompt))\n", + " # labels = prompt_ids + labels \n", + "\n", + " # Create a new dictionary with the processed data\n", + " processed_example = {\n", + " \"input_features\": input_features,\n", + " \"labels\": np.array(labels),\n", + " \"source.language\": example[\"source.language\"],\n", + " \"target.language\": example[\"target.language\"]\n", + " }\n", + "\n", + " return processed_example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "05Zyqa3cYCFW" + }, + "outputs": [], + "source": [ + "train_data = train_ds.map(prepare_dataset, remove_columns=[\"source\", \"target\"])\n", + "train_data = train_data.filter(lambda x: len(x[\"labels\"]) <= 448)\n", + "val_data = valid_ds.map(prepare_dataset, remove_columns=[\"source\", \"target\"])\n", + "val_data = val_data.filter(lambda x: len(x[\"labels\"]) <= 448)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UB4g9cW4rZ-u" + }, + "outputs": [], + "source": [ + "compute_metrics = salt.metrics.multilingual_eval_fn(\n", + " valid_ds, [evaluate.load('wer'), evaluate.load('cer')],\n", + " processor.tokenizer, log_first_N_predictions=5,\n", + " speech_processor=processor)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zCsAGEQtremE" + }, + "outputs": [], + "source": [ + "from transformers import GenerationConfig\n", + "\n", + "gen_config = GenerationConfig.from_pretrained(config['pretrained_model'])\n", + "gen_config.suppress_tokens = [] # This clears the default suppressed tokens\n", + "gen_config.max_length = config['training_args']['generation_max_length'] # maximum number of tokens to generate during evaluation and prediction\n", + "gen_config.forced_decoder_ids = None\n", + "model.generation_config = gen_config\n", + "\n", + "model.config.pad_token_id = processor.tokenizer.pad_token_id\n", + "model.config.eos_token_id = processor.tokenizer.eos_token_id\n", + "\n", + "if config['use_peft']:\n", + " model = peft.prepare_model_for_kbit_training(model)\n", + " lora_config = peft.LoraConfig(**config['lora_config'])\n", + " model.enable_input_require_grads()\n", + " model = peft.get_peft_model(model, lora_config)\n", + " model.config.use_cache = False\n", + " model.print_trainable_parameters()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Launch Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'output_dir': 'whisper-large-v3-multilingual',\n", + " 'per_device_train_batch_size': 2,\n", + " 'per_device_eval_batch_size': 2,\n", + " 'gradient_accumulation_steps': 4,\n", + " 'learning_rate': 1e-05,\n", + " 'warmup_steps': 500,\n", + " 'max_steps': 7500,\n", + " 'gradient_checkpointing': False,\n", + " 'gradient_checkpointing_kwargs': {'use_reentrant': False},\n", + " 'fp16': False,\n", + " 'bf16': True,\n", + " 'eval_strategy': 'steps',\n", + " 'predict_with_generate': True,\n", + " 'generation_max_length': 200,\n", + " 'save_steps': 50,\n", + " 'eval_steps': 50,\n", + " 'logging_steps': 50,\n", + " 'load_best_model_at_end': True,\n", + " 'metric_for_best_model': 'loss',\n", + " 'greater_is_better': False,\n", + " 'push_to_hub': True,\n", + " 'hub_model_id': 'jq/whisper-large-v3-salt-plus-xog-myx-kin-swa',\n", + " 'save_total_limit': 2}" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config[\"training_args\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2026/03/05 08:00:39 ERROR mlflow.utils.async_logging.async_logging_queue: Run Id 2d488acdc39146e9af9da07c00128d49: Failed to log run data: Exception: INVALID_PARAMETER_VALUE: Changing param values is not allowed. Params were already logged='[{'key': 'dtype', 'old_value': 'float16', 'new_value': 'float32'}]' for run ID='2d488acdc39146e9af9da07c00128d49'.\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [ 501/7500 59:28 < 13:54:17, 0.14 it/s, Epoch 41.00/9223372036854775807]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation LossWer LugWer MeanCer LugCer Mean
5010.6967102.0690253.5150003.5150001.8480001.848000
1006.5588421.0759071.1850001.1850000.4030000.403000
1503.0183020.3359721.1020001.1020000.4620000.462000
2000.8734940.0974640.0680000.0680000.0330000.033000
2500.3090790.0438880.0190000.0190000.0050000.005000
3000.1575420.0261270.2220000.2220000.1220000.122000
3500.1165160.0199680.9860000.9860000.7600000.760000
4000.0935780.0162141.0510001.0510000.6320000.632000
4500.0941010.0154600.0100000.0100000.0030000.003000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "A custom logits processor of type has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom will take precedence. Please check the docstring of to see related `.generate()` flags.\n", + "A custom logits processor of type has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom will take precedence. Please check the docstring of to see related `.generate()` flags.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First N predictions in eval set:\n", + "Prediction (lug to lug): \" I'm going to tell you a story. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n", + "Prediction (lug to lug): \" I'm going to tell you a story. I'm going to tell you a story.\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n", + "Prediction (lug to lug): \" Ikiorizo t'yamu t'yamu t'yamu wa ubu shita. Sibu t'yaze nekereza ko, ujava aribi. Nashobura bakongera kuwa ho uyu munzi. Guchika inege, kubura apeti n'ibindi. Nzuya badata, kuwa mumu ijima, kujeza kumuchu. Umutsi mukuru uzava ni mugorobu.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n", + "Prediction (lug to lug): \" I'm going to eat the food that you gave me. I'm going to eat the food that you gave me. I'm going to eat the food that you gave me. I'm going to eat the food that you gave me. I'm going to eat the food that you gave me. I'm going to eat the food that you gave me.\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n", + "Prediction (lug to lug): \" I'm going to tell you about the story of the man who killed his wife. I'm going to tell you about the story of the man who killed his wife. I'm going to tell you about the story of the man who killed his wife. I'm going to tell you about the story of the man who killed his wife. I'm going to tell you about the story of the man who killed his wife.\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "65fbf5ca2fda47c6a1b15048c8b66b70", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Writing model shards: 0%| | 0/1 [00:00 num_samples_to_affect: noise_start = np.random.randint(0, len(noise_sample) - num_samples_to_affect) noise_sample = noise_sample[noise_start:noise_start + num_samples_to_affect] - + + # If noise sample is empty, return the origanl audio + if len(noise_sample) == 0: + print( + f"[augment_audio_noise] Skipping empty noise sample" + f"[augment_audio_noise] len(x)={len(x)}, coverage={coverage:.2f}), num_samples_to_affect={num_samples_to_affect}" + ) + return r + # Normalize noise amplitude noise_max = np.amax(np.abs(noise_sample)) if noise_max > 0: # Avoid division by zero @@ -329,7 +345,7 @@ def augment_audio_noise(r, # Apply noise to the chosen segment x_with_noise = np.copy(x) # Make a copy of x to prevent altering the original x_with_noise[start_index:start_index + num_samples_to_affect] += noise - + r[src_or_tgt] = x_with_noise return r @@ -381,7 +397,7 @@ def clean_and_remove_punctuation( r[src_or_tgt] = ''.join([c for c in r[src_or_tgt] if c not in punct]) return r - + @single_batch_entry def lower_case(r, src_or_tgt): r[src_or_tgt] = r[src_or_tgt].lower() @@ -402,6 +418,5 @@ def set_sample_rate(r, src_or_tgt, rate, p=1.0): return r # TODO: Check that the order of preprocessing operations makes sense. For -# example, don't call match_target_sentence_format_to_source after +# example, don't call match_target_sentence_format_to_source after # prefix_dataset_tag (because then the tag is part of the text) - \ No newline at end of file