diff --git a/constants.py b/constants.py
index c7c271d..75da33d 100644
--- a/constants.py
+++ b/constants.py
@@ -80,19 +80,58 @@
SALT_LANGUAGE_TOKENS_WHISPER = {
# Exact/close mapping
- 'eng': 50259,
- 'swa': 50318,
+ "eng": 50259,
+ "fra": 50265,
+ "swa": 50318,
+ "sna": 50324,
+ "yor": 50325,
+ "som": 50326,
+ "afr": 50327,
+ "amh": 50334,
+ "mlg": 50349,
+ "lin": 50353,
+ "hau": 50354,
# Overwrite unused language tokens
- 'ach': 50357,
- 'lgg': 50356,
- 'lug': 50355,
- 'nyn': 50354,
- 'teo': 50353,
- 'xog': 50352,
- 'ttj': 50351,
- 'kin': 50350,
- 'myx': 50349,
- 'kik': 50348,
+ "ach": 50357,
+ "aka": 50356,
+ "bam": 50355,
+ "bem": 50352,
+ "ber": 50351,
+ "cgg": 50350,
+ "dag": 50348,
+ "dga": 50347,
+ "ewe": 50346,
+ "ful": 50345,
+ "ibo": 50344,
+ "kab": 50343,
+ "kau": 50342,
+ "kik": 50341,
+ "kin": 50340,
+ "kln": 50339,
+ "koo": 50338,
+ "kpo": 50337,
+ "led": 50336,
+ "lgg": 50335,
+ "lth": 50333,
+ "lug": 50332,
+ "luo": 50331,
+ "luy": 50330,
+ "myx": 50329,
+ "nbl": 50328,
+ "nya": 50323,
+ "nyn": 50322,
+ "orm": 50321,
+ "pcm": 50320,
+ "ruc": 50319,
+ "rwm": 50317,
+ "sot": 50316,
+ "teo": 50315,
+ "tsn": 50314,
+ "ttj": 50313,
+ "wol": 50312,
+ "xho": 50311,
+ "xog": 50310,
+ "zul": 50309,
}
SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION = {
diff --git a/dataset.py b/dataset.py
index a80d975..e028f05 100644
--- a/dataset.py
+++ b/dataset.py
@@ -73,10 +73,27 @@ def _add_speaker_id_studio_if_not_present(sample):
return sample
-def _load_single_huggingface_dataset(load_dataset_params):
- ds = datasets.load_dataset(**load_dataset_params)
+def _load_single_dataset(load_dataset_params, gcs_key_path=None, max_examples=None):
+    # If the path is on Google Cloud Storage (gcs://), load parquet files using the provided GCS key
+ if "path" in load_dataset_params and "gcs://" in load_dataset_params["path"]:
+ if gcs_key_path is None:
+ raise ValueError(
+ "gcs_key_path must be provided when loading from Google Cloud Storage."
+ )
+ print("downloading datasets from: ", load_dataset_params["path"])
+ ds = datasets.load_dataset(
+ "parquet",
+ data_files=load_dataset_params["path"],
+ storage_options={"token": gcs_key_path},
+ )
+ ds= ds.cast_column("audio", datasets.Audio())
+ print("download completed: ", load_dataset_params["path"])
+ # otherwise load from hugging face with the provided params
+ else:
+ ds = datasets.load_dataset(**load_dataset_params)
+
if isinstance(ds, datasets.DatasetDict):
- split_names = list(ds.data.keys())
+ split_names = list(ds.keys())
# If the split wasn't specified, but there's only one, then just go
# ahead and load that one.
if len(split_names) == 1:
@@ -88,6 +105,9 @@ def _load_single_huggingface_dataset(load_dataset_params):
f"{load_dataset_params}. Splits found: {list(ds.keys())}."
)
+ if max_examples is not None:
+ ds = ds.select(range(min(len(ds), max_examples)))
+
remap_names = {
"audio_language": "language",
}
@@ -206,30 +226,42 @@ def _dataset_id_from_config(load_params):
return "_".join(tag)
-def _load_huggingface_datasets(config):
- """Retrieve all specified HuggingFace datasets and return as a list."""
+def _load_datasets(config):
+ """Retrieve all specified datasets and return as a list."""
loaded_datasets = []
- if "huggingface_load" not in config:
+ if "datasets" not in config and "huggingface_load" not in config:
raise ValueError(
- "There should be a `huggingface_load` entry in the dataset config, "
+ "There should be a `datasets` or `huggingface_load` entry in the dataset config, "
f"specifying which datasets to download. Got: {config}."
)
- load_list = config["huggingface_load"]
+ load_list = config["datasets"] if "datasets" in config else config["huggingface_load"]
+ gcs_key_path = config.get("gcs_key_path")
+ max_examples = config.get("max_examples_per_dataset")
# Optionally pre-download everything at once
if config.get("download_datasets_in_parallel"):
+ # Disable progress bars to prevent tqdm thread-safety issues (e.g. `_lock` AttributeError)
+ datasets.disable_progress_bar()
+ # Explicitly initialize the tqdm lock in the main thread to prevent concurrent `del tqdm_class._lock` bugs
+ import tqdm
+ tqdm.tqdm.get_lock()
+
threads = []
for l in _ensure_list(load_list):
if "join" in l:
for i in (0, 1):
thread = threading.Thread(
- target=_load_single_huggingface_dataset, args=(l["join"][i],)
+ target=_load_single_dataset,
+ args=(l["join"][i],),
+ kwargs={"gcs_key_path": gcs_key_path},
)
threads.append(thread)
else:
thread = threading.Thread(
- target=_load_single_huggingface_dataset, args=(l,)
+ target=_load_single_dataset,
+ args=(l,),
+ kwargs={"gcs_key_path": gcs_key_path},
)
threads.append(thread)
thread.start()
@@ -243,8 +275,12 @@ def _load_huggingface_datasets(config):
"If a dataset join is specified, then there should be a "
f"list of exactly two datasets to be joined. Got: {l}."
)
- left = _load_single_huggingface_dataset(l["join"][0])
- right = _load_single_huggingface_dataset(l["join"][1])
+ left = _load_single_dataset(
+ l["join"][0], gcs_key_path=gcs_key_path, max_examples=max_examples
+ )
+ right = _load_single_dataset(
+ l["join"][1], gcs_key_path=gcs_key_path, max_examples=max_examples
+ )
generator_function = lambda: _combine_datasets_generator(left, right)
ds = datasets.IterableDataset.from_generator(generator_function)
@@ -254,7 +290,9 @@ def _load_huggingface_datasets(config):
+ _dataset_id_from_config(l["join"][1])
)
else:
- ds = _load_single_huggingface_dataset(l)
+ ds = _load_single_dataset(
+ l, gcs_key_path=gcs_key_path, max_examples=max_examples
+ )
dataset_id = _dataset_id_from_config(l)
loaded_datasets.append([ds, dataset_id])
@@ -443,13 +481,12 @@ def _matching_pairs(row, config):
yield example
-def _create_generator(config, verbose=False):
+def _create_generator(loaded_datasets, config, verbose=False):
"""Make a generator that yields examples according to dataset spec."""
- huggingface_datasets = _load_huggingface_datasets(config)
if verbose:
total_row_count = 0
- for ds, id in huggingface_datasets:
+ for ds, id in loaded_datasets:
row_count = len(ds)
total_row_count += row_count
print(f"{id}: {row_count} rows")
@@ -459,29 +496,40 @@ def _yield_matches(batch, config, dataset_id):
keys = list(batch.keys())
rows = [{k: batch[k][i] for k in keys} for i in range(len(batch[keys[0]]))]
for row in rows:
- # The audio SALT datasets are in a slightly different format
- # to the translation data, each row having a 'text' and 'language'
- # field.
- if "audio" in row and "text" in row:
- row[row["language"] + "_text"] = row["text"]
- del row["text"]
- for match in _matching_pairs(row | {"origin_dataset": dataset_id}, config):
- yield match
+ if config.get("skip_matching_asr"):
+ audio_array, sample_rate = _get_audio_from_row(row)
+ example = {
+ "source": audio_array,
+ "source.sample_rate": sample_rate,
+ "source.language": row.get("language"),
+ "target": row.get("text"),
+ "target.language": row.get("language"),
+ }
+ yield example
+ else:
+ # The audio SALT datasets are in a slightly different format
+ # to the translation data, each row having a 'text' and 'language'
+ # field.
+ if "audio" in row and "text" in row:
+ row[row["language"] + "_text"] = row["text"]
+ del row["text"]
+ for match in _matching_pairs(row | {"origin_dataset": dataset_id}, config):
+ yield match
# PyArrow data should be read in batches for speed.
PYARROW_BATCH_SIZE = 10
num_workers = config.get("num_workers", 4)
- if config.get("shuffle") and len(huggingface_datasets) > 1:
+ if config.get("shuffle") and len(loaded_datasets) > 1:
# If there are multiple datasets concatenated and 'shuffle' is
# specified, then we want to randomly interleave them.
iterators = [
- d[0].iter(batch_size=PYARROW_BATCH_SIZE) for d in huggingface_datasets
+ d[0].iter(batch_size=PYARROW_BATCH_SIZE) for d in loaded_datasets
]
iterator_order = []
- for i in range(len(huggingface_datasets)):
+ for i in range(len(loaded_datasets)):
num_batches = math.ceil(
- len(huggingface_datasets[i][0]) / PYARROW_BATCH_SIZE
+ len(loaded_datasets[i][0]) / PYARROW_BATCH_SIZE
)
iterator_order.extend([i] * num_batches)
permutation = np.random.permutation(len(iterator_order))
@@ -502,13 +550,13 @@ def process_batch(args):
batch = next(iterators[iterator_id])
future = executor.submit(
process_batch,
- (batch, config, huggingface_datasets[iterator_id][1]),
+ (batch, config, loaded_datasets[iterator_id][1]),
)
future_queue.put((future, iterator_id))
except StopIteration:
break
except Exception as e:
- print("Error reading from " + huggingface_datasets[iterator_id][1])
+ print("Error reading from " + loaded_datasets[iterator_id][1])
raise
while not future_queue.empty():
future, iterator_id = future_queue.get()
@@ -520,13 +568,13 @@ def process_batch(args):
batch = next(iterators[iterator_id])
next_future = executor.submit(
process_batch,
- (batch, config, huggingface_datasets[iterator_id][1]),
+ (batch, config, loaded_datasets[iterator_id][1]),
)
future_queue.put((next_future, iterator_id))
except StopIteration:
continue
except Exception as e:
- print("Error reading from " + huggingface_datasets[iterator_id][1])
+ print("Error reading from " + loaded_datasets[iterator_id][1])
raise
elif config.get("shuffle"):
# Single dataset, shuffle is True: parallelize batches, order doesn't matter
@@ -536,7 +584,7 @@ def process_batch(args):
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
- for ds, dataset_id in huggingface_datasets:
+ for ds, dataset_id in loaded_datasets:
for batch in ds.iter(batch_size=PYARROW_BATCH_SIZE):
futures.append(
executor.submit(process_batch, (batch, config, dataset_id))
@@ -545,10 +593,43 @@ def process_batch(args):
for match in future.result():
yield match
else:
- # No shuffle: preserve strict order, process sequentially
- for ds, dataset_id in huggingface_datasets:
- for batch in ds.iter(batch_size=PYARROW_BATCH_SIZE):
- yield from _yield_matches(batch, config, dataset_id)
+ # No shuffle: preserve strict order, parallelize processing
+ def process_batch(args):
+ batch, config, dataset_id = args
+ return list(_yield_matches(batch, config, dataset_id))
+
+ prefetch_limit = 4 * num_workers
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+ future_queue = queue.Queue()
+
+ def batch_generator():
+ for ds, dataset_id in loaded_datasets:
+ for batch in ds.iter(batch_size=PYARROW_BATCH_SIZE):
+ yield batch, config, dataset_id
+
+ batch_iter = batch_generator()
+
+ # Prefill the queue
+ for _ in range(prefetch_limit):
+ try:
+ args = next(batch_iter)
+ future = executor.submit(process_batch, args)
+ future_queue.put(future)
+ except StopIteration:
+ break
+
+ while not future_queue.empty():
+ future = future_queue.get()
+ for match in future.result():
+ yield match
+
+ # Submit the next batch if available
+ try:
+ args = next(batch_iter)
+ next_future = executor.submit(process_batch, args)
+ future_queue.put(next_future)
+ except StopIteration:
+ continue
def _compose(functions):
@@ -627,7 +708,7 @@ def create(config, verbose=False):
Create a dataset from the given configuration.
Args:
- huggingface_load : Dict containing keyword arguments to HuggingFace
+ datasets : Dict containing keyword arguments to
datasets.load_dataset(), or a list of dicts to load multiple
datasets. The dataset should be in SALT format, as per
hf.co/datasets/sunbird/salt. Common Voice is also supported, if
@@ -662,8 +743,12 @@ def create(config, verbose=False):
f"{language}. Change to [{language}] to make it a list."
)
- generator_function = lambda: _create_generator(config, verbose=verbose)
+ loaded_datasets = _load_datasets(config)
+ total_row_count = sum(len(ds) for ds, id in loaded_datasets)
+
+ generator_function = lambda: _create_generator(loaded_datasets, config, verbose=verbose)
ds = datasets.IterableDataset.from_generator(generator_function)
+ ds.approx_row_count = total_row_count
# The individual datasets are already shuffled as needed, but do a little
# more so that consecutive samples are from different batches.
@@ -678,4 +763,5 @@ def create(config, verbose=False):
ds = ds.select_columns(
["source", "target", "source.language", "target.language"]
)
+ ds.approx_row_count = total_row_count
return ds
diff --git a/metrics.py b/metrics.py
index 0489a1f..f3d74b7 100644
--- a/metrics.py
+++ b/metrics.py
@@ -4,6 +4,7 @@
import pandas as pd
import functools
import numbers
+import tqdm
def _normalise(string_list,
@@ -82,7 +83,7 @@ def _round_if_float(f, p):
f'True label: "{decoded_labels[i]}"')
subsets = {}
- for i in range(len(decoded_predictions)):
+ for i in tqdm.tqdm(range(len(decoded_predictions)), desc="Grouping language subsets"):
if speech_processor:
# For speech metrics, such as WER, we evaluate for separate target
# languages.
@@ -146,9 +147,25 @@ def multilingual_eval_fn(eval_dataset,
If `speech_processor` is defined, then it is used to decode the predictions.
Otherwise `tokenizer` is used (e.g. for translation tasks).'''
- df = pd.DataFrame(eval_dataset)
- source_language = list(df['source.language'])
- target_language = list(df['target.language'])
+ try:
+ from . import dataset as salt_dataset
+ except ImportError:
+ import dataset as salt_dataset
+
+ original_get_audio = salt_dataset._get_audio_from_row
+
+ try:
+ # Mock audio extraction to bypass heavy CPU decoding
+ salt_dataset._get_audio_from_row = lambda row: (np.zeros(1), 16000)
+
+ source_language = []
+ target_language = []
+ for item in tqdm.tqdm(eval_dataset, desc="Extracting source and target languages"):
+ source_language.append(item['source.language'])
+ target_language.append(item['target.language'])
+ finally:
+ # Restore the original method so Trainer evaluation loop works correctly
+ salt_dataset._get_audio_from_row = original_get_audio
metric_names = []
for m in metrics:
diff --git a/notebooks/training/whisper_finetuning_gcs.ipynb b/notebooks/training/whisper_finetuning_gcs.ipynb
new file mode 100644
index 0000000..3acb613
--- /dev/null
+++ b/notebooks/training/whisper_finetuning_gcs.ipynb
@@ -0,0 +1,1035 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Working on RTX 6000Ada 48GB (per-device batch size 2) and H100 80GB (per-device batch size 16)\n",
+ "\n",
+ "Before running cells below, please open a terminal and run the following command under the repo's root directory to set up the environment:\n",
+ "```\n",
+ "bash whisper_training_setup.sh\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4wWOTwrVCqnF"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (7.2.1)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0mRequirement already satisfied: pynvml in /usr/local/lib/python3.12/dist-packages (13.0.1)\n",
+ "Requirement already satisfied: nvidia-ml-py>=12.0.0 in /usr/local/lib/python3.12/dist-packages (from pynvml) (13.590.48)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Enter the MLFLOW_TRACKING_USERNAME: ········\n",
+ "Enter the MLFLOW_TRACKING_PASSWORD: ········\n"
+ ]
+ }
+ ],
+ "source": [
+ "use_wandb = False\n",
+ "use_mlflow = True\n",
+ "\n",
+ "import importlib.metadata\n",
+ "installed = [\n",
+ " dist.metadata['Name']\n",
+ " for dist in importlib.metadata.distributions()\n",
+ "]\n",
+ "\n",
+ "if use_wandb:\n",
+ " !pip install wandb\n",
+ " import wandb\n",
+ " %set_env WANDB_LOG_MODEL=True\n",
+ " %set_env WANDB_WATCH=all\n",
+ " %set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb\n",
+ " wandb.login()\n",
+ "\n",
+ "if use_mlflow:\n",
+ " if 'mlflow' not in installed:\n",
+ " !pip install mlflow\n",
+ " ## requirements to log system/GPU metrics in mlflow\n",
+ " !pip install psutil\n",
+ " !pip install pynvml\n",
+ " import os\n",
+ " from getpass import getpass\n",
+ " import mlflow\n",
+ " import mlflow.pytorch\n",
+ " from mlflow import MlflowClient\n",
+ "\n",
+ " mlflow.enable_system_metrics_logging()\n",
+ "\n",
+ " # Set MLflow tracking credentials\n",
+ " MLFLOW_TRACKING_USERNAME = getpass('Enter the MLFLOW_TRACKING_USERNAME: ')\n",
+ " os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME\n",
+ "\n",
+ " MLFLOW_TRACKING_PASSWORD = getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')\n",
+ " os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD\n",
+ "\n",
+ " # Set the MLflow tracking URI\n",
+ " mlflow.set_tracking_uri('https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/')\n",
+ " mlflow.system_metrics.enable_system_metrics_logging()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "6KKzjwqnb6Dh"
+ },
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "from pathlib import Path\n",
+ "root_dir = Path.cwd().parent.parent.parent\n",
+ "sys.path.append(str(root_dir))\n",
+ "\n",
+ "import torch\n",
+ "import transformers\n",
+ "from dataclasses import dataclass, field\n",
+ "from typing import Union, List, Dict, Any\n",
+ "import string\n",
+ "import os\n",
+ "import json\n",
+ "import datasets\n",
+ "import numpy as np\n",
+ "import yaml\n",
+ "import evaluate\n",
+ "import salt.dataset\n",
+ "import salt.metrics\n",
+ "import salt.constants\n",
+ "from salt.utils import DataCollatorCTCWithPadding as dcwp\n",
+ "import huggingface_hub\n",
+ "import peft\n",
+ "import pandas as pd\n",
+ "import gcsfs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "IlbdSLfKNYfF",
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "huggingface_hub.notebook_login()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "List all the available datasets in the Google Cloud Storage buckets for training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "gcs = gcsfs.GCSFileSystem(project='sb-gcp-project-01', token='gcs-data-viewer-key.json')\n",
+ "bucket_path = \"sunflower-data/speech\"\n",
+ "base_gcs = f\"gcs://{bucket_path}\"\n",
+ "all_datasets = gcs.ls(path=bucket_path, detail=False)\n",
+ "all_datasets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create Datasets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Load the YAML config file that specifies the datasets and processing methods for training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('configs/whisper_finetuning_gcs.yaml', 'r') as file:\n",
+ " config = yaml.safe_load(file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# # show the datasets not included for training (sometimes just empty folders)\n",
+ "# train_datasets = set([ds['path'].replace('/train/*.parquet', '').replace('gcs://', '') \n",
+ "# for ds in config['train']['datasets']])\n",
+ "# set(all_datasets) - train_datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "i1Vd4A4UIwLk"
+ },
+ "outputs": [],
+ "source": [
+ "train_ds = salt.dataset.create(config['train'], verbose=True)\n",
+ "valid_ds = salt.dataset.create(config['validation'], verbose=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Inspect the processed datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--- Example 0 ---\n",
+ "source: [ 2.0106706e-06 3.2367443e-06 -2.7039619e-06 ... -1.4706944e-06\n",
+ " 2.6853128e-05 1.4958448e-05]\n",
+ "target: Ikiraro cya gari ya moshi kimaze kubakwa hejuru yuruzi. Ejo bundi nasanze igitabo cyanditswe na data. Ugomba guhindura uruhushya rwawe rwa kera kugirango ubone urundi rushya. Yemeza mushiki wa Denis, Ndayishimiye yakomeje avuga ko hari ingaruka nyinshi. Maze bigafasha mwarimu uburyo yatangamo isomo rye neza.\n",
+ "Languages: lug -> lug\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "for i, example in enumerate(train_ds.take(1)):\n",
+ " print(f\"--- Example {i} ---\")\n",
+ " \n",
+ " # Check if 'source' or 'target' contains large audio arrays\n",
+ " # If they are dicts with 'array' keys, print just the shape/metadata\n",
+ " for key in ['source', 'target']:\n",
+ " data = example.get(key)\n",
+ " if isinstance(data, dict) and 'array' in data:\n",
+ " # Avoid printing the massive raw audio array\n",
+ " print(f\"{key}: [Audio Data] Sample Rate: {data.get('sampling_rate')}, Shape: {data['array'].shape}\")\n",
+ " else:\n",
+ " print(f\"{key}: {data}\")\n",
+ " \n",
+ " print(f\"Languages: {example.get('source.language')} -> {example.get('target.language')}\\n\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yeE3B_MdJxAu"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " source | \n",
+ " target | \n",
+ " source.language | \n",
+ " target.language | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " | \n",
+ " Hazaba ikindi cyerekanwa cyiyi firime mumasaha abiri. Shimirwa nyagasani, imbaraga zawe. David ashobora gufungwa burundu. Akajya ayatekereza. Maze abungeri basubirayo bahimbaza, bashima Imana. | \n",
+ " lug | \n",
+ " lug | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "salt.utils.show_dataset(train_ds, audio_features=['source'], N=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load Whisper Model and Data Pipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9BNxLEzRpNey"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f876e09209a64c8b9b6a02702f4bf0df",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading weights: 0%| | 0/1259 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(\n",
+ " config['pretrained_model'])\n",
+ "processor = transformers.WhisperProcessor.from_pretrained(\n",
+ " config['pretrained_model'], language=None, task=\"transcribe\")\n",
+ "model = transformers.WhisperForConditionalGeneration.from_pretrained(\n",
+ " config['pretrained_model'], torch_dtype=torch.float32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "00Jd-YTThouQ"
+ },
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class DataCollatorSpeechSeq2SeqWithPadding:\n",
+ " processor: Any\n",
+ " decoder_start_token_id: int\n",
+ "\n",
+ " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
+ " # split inputs and labels since they have to be of different lengths and need different padding methods\n",
+ " # first treat the audio inputs by simply returning torch tensors\n",
+ " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features] \n",
+ " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
+ "\n",
+ " # get the tokenized label sequences\n",
+ " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
+ " # pad the labels to max length\n",
+ " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
+ "\n",
+ " # replace padding with -100 to ignore loss correctly\n",
+ " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
+ "\n",
+ " # if bos token is appended in previous tokenization step,\n",
+    "        # cut bos token here as it's appended later anyway\n",
+ " if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():\n",
+ " labels = labels[:, 1:]\n",
+ "\n",
+ " batch[\"labels\"] = labels\n",
+ "\n",
+ " return batch\n",
+ "\n",
+ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(\n",
+ " processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Read in prompts: preceding text which is used to guide the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# sentences = datasets.load_dataset(\n",
+ "# 'Sunbird/salt', 'text-all', split='train').to_pandas()\n",
+ "# prompts = datasets.load_dataset(\n",
+ "# 'Sunbird/prompts', split='train').to_pandas()\n",
+ "# joined = pd.merge(sentences, prompts, on='id', how='inner')\n",
+ "# SALT_PROMPT_LANGUAGES = ['eng', 'ach', 'lgg', 'lug', 'nyn', 'teo']\n",
+ "# sentence_to_prompt = {}\n",
+ "# for language in SALT_PROMPT_LANGUAGES:\n",
+ "# sentence_to_prompt[language] = dict(\n",
+ "# zip(joined[f'{language}_text'], joined[f'{language}_prompt']))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4mzVFDogXgLG"
+ },
+ "outputs": [],
+ "source": [
+ "language_id_tokens = salt.constants.SALT_LANGUAGE_TOKENS_WHISPER\n",
+ "\n",
+ "def prepare_dataset(example, p_prompt = 0.5): \n",
+ " audio = example[\"source\"]\n",
+ " input_features = feature_extractor(\n",
+ " audio, sampling_rate=16000, device='cuda',\n",
+ " do_normalize=True).input_features[0]\n",
+ "\n",
+ " # Encode target text to label ids\n",
+ " labels = processor.tokenizer(str(example[\"target\"])).input_ids\n",
+ "\n",
+ " # Insert the language ID token into the second position of the sequence.\n",
+ " labels.insert(1, language_id_tokens[example[\"target.language\"]])\n",
+ "\n",
+ " # # If a prompt is known for a particular sentence, add it to the\n",
+ " # # training example with probability `p_prompt`.\n",
+ " # if example[\"target.language\"] in sentence_to_prompt:\n",
+ " # prompt = sentence_to_prompt[example[\"target.language\"]].get(example[\"target\"], None)\n",
+ " # if prompt:\n",
+ " # if np.random.random() < p_prompt:\n",
+ " # prompt_ids = list(processor.get_prompt_ids(prompt))\n",
+ " # labels = prompt_ids + labels \n",
+ "\n",
+ " # Create a new dictionary with the processed data\n",
+ " processed_example = {\n",
+ " \"input_features\": input_features,\n",
+ " \"labels\": np.array(labels),\n",
+ " \"source.language\": example[\"source.language\"],\n",
+ " \"target.language\": example[\"target.language\"]\n",
+ " }\n",
+ "\n",
+ " return processed_example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "05Zyqa3cYCFW"
+ },
+ "outputs": [],
+ "source": [
+ "train_data = train_ds.map(prepare_dataset, remove_columns=[\"source\", \"target\"])\n",
+ "train_data = train_data.filter(lambda x: len(x[\"labels\"]) <= 448)\n",
+ "val_data = valid_ds.map(prepare_dataset, remove_columns=[\"source\", \"target\"])\n",
+ "val_data = val_data.filter(lambda x: len(x[\"labels\"]) <= 448)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "UB4g9cW4rZ-u"
+ },
+ "outputs": [],
+ "source": [
+ "compute_metrics = salt.metrics.multilingual_eval_fn(\n",
+ " valid_ds, [evaluate.load('wer'), evaluate.load('cer')],\n",
+ " processor.tokenizer, log_first_N_predictions=5,\n",
+ " speech_processor=processor)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zCsAGEQtremE"
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import GenerationConfig\n",
+ "\n",
+ "gen_config = GenerationConfig.from_pretrained(config['pretrained_model'])\n",
+ "gen_config.suppress_tokens = [] # This clears the default suppressed tokens\n",
+ "gen_config.max_length = config['training_args']['generation_max_length'] # maximum number of tokens to generate during evaluation and prediction\n",
+ "gen_config.forced_decoder_ids = None\n",
+ "model.generation_config = gen_config\n",
+ "\n",
+ "model.config.pad_token_id = processor.tokenizer.pad_token_id\n",
+ "model.config.eos_token_id = processor.tokenizer.eos_token_id\n",
+ "\n",
+ "if config['use_peft']:\n",
+ " model = peft.prepare_model_for_kbit_training(model)\n",
+ " lora_config = peft.LoraConfig(**config['lora_config'])\n",
+ " model.enable_input_require_grads()\n",
+ " model = peft.get_peft_model(model, lora_config)\n",
+ " model.config.use_cache = False\n",
+ " model.print_trainable_parameters()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Launch Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'output_dir': 'whisper-large-v3-multilingual',\n",
+ " 'per_device_train_batch_size': 2,\n",
+ " 'per_device_eval_batch_size': 2,\n",
+ " 'gradient_accumulation_steps': 4,\n",
+ " 'learning_rate': 1e-05,\n",
+ " 'warmup_steps': 500,\n",
+ " 'max_steps': 7500,\n",
+ " 'gradient_checkpointing': False,\n",
+ " 'gradient_checkpointing_kwargs': {'use_reentrant': False},\n",
+ " 'fp16': False,\n",
+ " 'bf16': True,\n",
+ " 'eval_strategy': 'steps',\n",
+ " 'predict_with_generate': True,\n",
+ " 'generation_max_length': 200,\n",
+ " 'save_steps': 50,\n",
+ " 'eval_steps': 50,\n",
+ " 'logging_steps': 50,\n",
+ " 'load_best_model_at_end': True,\n",
+ " 'metric_for_best_model': 'loss',\n",
+ " 'greater_is_better': False,\n",
+ " 'push_to_hub': True,\n",
+ " 'hub_model_id': 'jq/whisper-large-v3-salt-plus-xog-myx-kin-swa',\n",
+ " 'save_total_limit': 2}"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "config[\"training_args\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2026/03/05 08:00:39 ERROR mlflow.utils.async_logging.async_logging_queue: Run Id 2d488acdc39146e9af9da07c00128d49: Failed to log run data: Exception: INVALID_PARAMETER_VALUE: Changing param values is not allowed. Params were already logged='[{'key': 'dtype', 'old_value': 'float16', 'new_value': 'float32'}]' for run ID='2d488acdc39146e9af9da07c00128d49'.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 501/7500 59:28 < 13:54:17, 0.14 it/s, Epoch 41.00/9223372036854775807]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | Step | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ " Wer Lug | \n",
+ " Wer Mean | \n",
+ " Cer Lug | \n",
+ " Cer Mean | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 50 | \n",
+ " 10.696710 | \n",
+ " 2.069025 | \n",
+ " 3.515000 | \n",
+ " 3.515000 | \n",
+ " 1.848000 | \n",
+ " 1.848000 | \n",
+ "
\n",
+ " \n",
+ " | 100 | \n",
+ " 6.558842 | \n",
+ " 1.075907 | \n",
+ " 1.185000 | \n",
+ " 1.185000 | \n",
+ " 0.403000 | \n",
+ " 0.403000 | \n",
+ "
\n",
+ " \n",
+ " | 150 | \n",
+ " 3.018302 | \n",
+ " 0.335972 | \n",
+ " 1.102000 | \n",
+ " 1.102000 | \n",
+ " 0.462000 | \n",
+ " 0.462000 | \n",
+ "
\n",
+ " \n",
+ " | 200 | \n",
+ " 0.873494 | \n",
+ " 0.097464 | \n",
+ " 0.068000 | \n",
+ " 0.068000 | \n",
+ " 0.033000 | \n",
+ " 0.033000 | \n",
+ "
\n",
+ " \n",
+ " | 250 | \n",
+ " 0.309079 | \n",
+ " 0.043888 | \n",
+ " 0.019000 | \n",
+ " 0.019000 | \n",
+ " 0.005000 | \n",
+ " 0.005000 | \n",
+ "
\n",
+ " \n",
+ " | 300 | \n",
+ " 0.157542 | \n",
+ " 0.026127 | \n",
+ " 0.222000 | \n",
+ " 0.222000 | \n",
+ " 0.122000 | \n",
+ " 0.122000 | \n",
+ "
\n",
+ " \n",
+ " | 350 | \n",
+ " 0.116516 | \n",
+ " 0.019968 | \n",
+ " 0.986000 | \n",
+ " 0.986000 | \n",
+ " 0.760000 | \n",
+ " 0.760000 | \n",
+ "
\n",
+ " \n",
+ " | 400 | \n",
+ " 0.093578 | \n",
+ " 0.016214 | \n",
+ " 1.051000 | \n",
+ " 1.051000 | \n",
+ " 0.632000 | \n",
+ " 0.632000 | \n",
+ "
\n",
+ " \n",
+ " | 450 | \n",
+ " 0.094101 | \n",
+ " 0.015460 | \n",
+ " 0.010000 | \n",
+ " 0.010000 | \n",
+ " 0.003000 | \n",
+ " 0.003000 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "A custom logits processor of type has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom will take precedence. Please check the docstring of to see related `.generate()` flags.\n",
+ "A custom logits processor of type has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom will take precedence. Please check the docstring of to see related `.generate()` flags.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First N predictions in eval set:\n",
+ "Prediction (lug to lug): \" I'm going to tell you a story. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I was born in a village called Mokpe. I\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n",
+ "Prediction (lug to lug): \" I'm going to tell you a story. I'm going to tell you a story.\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n",
+ "Prediction (lug to lug): \" Ikiorizo t'yamu t'yamu t'yamu wa ubu shita. Sibu t'yaze nekereza ko, ujava aribi. Nashobura bakongera kuwa ho uyu munzi. Guchika inege, kubura apeti n'ibindi. Nzuya badata, kuwa mumu ijima, kujeza kumuchu. Umutsi mukuru uzava ni mugorobu.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n",
+ "Prediction (lug to lug): \" I'm going to eat the food that you gave me. I'm going to eat the food that you gave me. I'm going to eat the food that you gave me. I'm going to eat the food that you gave me. I'm going to eat the food that you gave me. I'm going to eat the food that you gave me.\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n",
+ "Prediction (lug to lug): \" I'm going to tell you about the story of the man who killed his wife. I'm going to tell you about the story of the man who killed his wife. I'm going to tell you about the story of the man who killed his wife. I'm going to tell you about the story of the man who killed his wife. I'm going to tell you about the story of the man who killed his wife.\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "65fbf5ca2fda47c6a1b15048c8b66b70",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First N predictions in eval set:\n",
+ "Prediction (lug to lug): \"Buberwa n'imandii ku mu biri. Guvusuje hii bisabdeyi heko soo mu. Nkunelelo nguyu mu ahami, mihitiye mu kandi n'insabiyo. Hari ibyo ashinse atahoze neza ibyo cyoza. Cicinda byacuze ko aricam n'ira bayacana byo byo bisacay.\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n",
+ "Prediction (lug to lug): \"Nize irika ba hanga ba muhe baza vumburavu ba umuti wa cyida. Muri izo nka mba harimo n'ize indi. Abano n'abari q'uibaza uku riki w'ihuwa n'ibyo. Nk'uumyaan byo ka bagorebeza q'uisi baba ibo ostoma. Uru bendo rwa mberu ba Citanic, ubu waxe nze neze. Nda muhe inyuuraho mu gi tondo no mu saye. Uta yeda kuwa kuibu ko ubihiye ku ribyo.\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n",
+ "Prediction (lug to lug): \"Ikyo rizo cyamu cyamu cyangwa ugu shita. Sini geze inhekereza ko byaba aribibi. Nashobora ga kongera ku baho uyu mu nsi. Gucika inhege kubura apeti n'ibindi. Inzu ya ba data kuva mu mu ijima ku geza ku mu cyo. Umo cyi mu kuru uzaba n'imugoroho.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n",
+ "Prediction (lug to lug): \"Ni nkwebera ko we wa mwence u bwence. Ntisi si mahuma. Cagicuva cibumbo go cyaramuka cyaji. Ntaho baramuka bashate ku kugira nabi. Batera n'iyako aska uba gandira ku kubirongo incaansi bekaandiba ka bazanyarabibazo bishiyo iyeko ubi biriha.\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n",
+ "Prediction (lug to lug): \"Ibyo bina n'izo n'uaro da tu mu n'abashya u-ma ra iyo in'yaka akarahu ka n'anayi. So goku razaba y'arasoniye bihiri incu bici mi n'uyangera ku isoma. Nagiye mu bwisho in'cashyade in'yaki bici. Arian ugo ba qoreshi indi mu mu quiri indi n'wara. Kuberu mu baano e asangaa wa yagizo ruo ahari. Yemere waxa andaga shishikarizwa guko meza gukora.\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4d83778823ed4d688ece2803093085a8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First N predictions in eval set:\n",
+ "Prediction (lug to lug): \"Ruberwa n' imanzi ku mubiri. Nkuzuye nkibisabye inteko zwo nkube. Nkone rero nguyu umwami nkihitiyemo kandi nkisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itizwinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n",
+ "Prediction (lug to lug): \" Niseye kaba hanga bamwe baza vumbura vuba umuti wa cyida. Muri izo ngamba harimo nizindi. Abano ntabari cy'ibaza uko rikiwihuwa ntibi. Umo miyangu iyo kaba gore beza ku isi baba i Boston. Uru bendo rwa mberu ba Titanic ugo genze neza. Inda mu enyura nta mu gihondo ntumusaye. Utayega ko haqwiyi gukora bici ucribyo.\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n",
+ "Prediction (lug to lug): \"Icyorezo cyamugyamu cyangwa ubushyita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora gakongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mu cyo. Umunsi mukuru uzaba nimugoroba.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n",
+ "Prediction (lug to lug): \"Ini nka kubera ko wo wampenze ubwenge. Iyabo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cyo cy\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n",
+ "Prediction (lug to lug): \"Ibyo bintu nizo ntaaro batumu ntabashya umarayyi miyaka acarahu kanna ntariyi. So gokura zaba y'arasomye Bibiliyi nshu bihici mi niyongera ku isoma. Nagiiyo mu bushimwa shiizu miyaka ibiri. Ari n'ubwo bakoreshi indi mu mu quirindi n'wara. Kubera umubano we asa ntaho y'agizuru hare. Yemeri w'akanda gashishikarizwa gukomeza gukora.\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "265c36d8ffcf49bbb80b413aebbdda5a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First N predictions in eval set:\n",
+ "Prediction (lug to lug): \"Ruberwa n' imanzi ku mubiri. Bwuzuye hibisabye Inteko Nkuru Mube. None rero, nguyu muami nti hitiye mu kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo ntirabayazana wibyo bisasu.\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n",
+ "Prediction (lug to lug): \"низere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izongamba harimo nizindi. Abantu bamwe abare kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n",
+ "Prediction (lug to lug): \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n",
+ "Prediction (lug to lug): \"Ini nka kubera ko wo wampenze ubwenge. Yabo cyo kubwaza nkubwaza. Incy si mahuma? Ni. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku birongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya.\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n",
+ "Prediction (lug to lug): \"Ibyo bintu nizo ntwaro zatumye ntabasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa ntaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e06b00eca5594f36b49c78cd3cccdc8d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First N predictions in eval set:\n",
+ "Prediction (lug to lug): \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n",
+ "Prediction (lug to lug): \"Nico haangaba mwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye kaba gore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo.\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n",
+ "Prediction (lug to lug): \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n",
+ "Prediction (lug to lug): \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n",
+ "Prediction (lug to lug): \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ede9313d839b4390916492da43dec64c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First N predictions in eval set:\n",
+ "Prediction (lug to lug): \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n",
+ "Prediction (lug to lug): \"NISere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n",
+ "Prediction (lug to lug): \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n",
+ "Prediction (lug to lug): \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'ibumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n",
+ "Prediction (lug to lug): \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogo kurazaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho y'agize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "274c5bba85ea47e3a59c01cdfe059a0a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First N predictions in eval set:\n",
+ "Prediction (lug to lug): \".\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n",
+ "Prediction (lug to lug): \".....................................................................................................................................................................................................\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n",
+ "Prediction (lug to lug): \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n",
+ "Prediction (lug to lug): \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n",
+ "Prediction (lug to lug): \".\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bed7cc92f42e4aa7853f696b869addfb",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First N predictions in eval set:\n",
+ "Prediction (lug to lug): \". Ruberwa nimanzi kumubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n",
+ "Prediction (lug to lug): \"..................................................................................................................\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n",
+ "Prediction (lug to lug): \"Icyorezo cya muryamo cyangwa ubushyita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n",
+ "Prediction (lug to lug): \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n",
+ "Prediction (lug to lug): \".\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "95a5afe42618479baab4e9a39ff6b0c2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First N predictions in eval set:\n",
+ "Prediction (lug to lug): \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\", True label: \"Ruberwa n' imanzi ku mubiri. Bwuzuye ribisabye Inteko Nkuru mu. None rero, nguyu umwami mwihitiyemo kandi mwisabiye. Hari ibyo ashinzwe atakoze neza igihe cyose. Itsinda ryavuze ko ariryo nyirabayazana w'ibyo bisasu.\"\n",
+ "Prediction (lug to lug): \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\", True label: \"[S] Nizere ko abahanga bamwe bazavumbura vuba umuti wa sida. Muri izo ngamba harimo nizindi. Abantu bamwe bari kwibaza ukuri kwibihuha nkibi. Umuntu yambwiye ko abagore beza ku isi baba i Boston. Urugendo rwa mbere rwa Titanic ntirwagenze neza. Ndamwenyura nko mugitondo numusaya. Utekereza ko hakwiye gukorwa iki kuri byo?\"\n",
+ "Prediction (lug to lug): \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\", True label: \"Icyorezo cya muryamo cyangwa ubushita. Sinigeze ntekereza ko byaba ari bibi. Ntashobora kongera kubaho uyu munsi. Gucika intege, Kubura appétit n'ibindi. Inzu ya ba data kuva mu mwijima kugeza ku mucyo. Umunsi mukuru uzaba nimugoroba.\"\n",
+ "Prediction (lug to lug): \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaka. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\", True label: \"Iyi nka kubera ko wowe wampenze ubwenge. Yabo ko bakora igikorwa cyo kubyaza. • impyisi mahuma. Ni itsinda ry'abantu bo mu muryango umwe. Cya gicuba k'i Bumbogo cyaramuka cyaje. N'aho baramuka bashatse kukugirira nabi. Bateraniye kwa Skau bakaganira ku mirongo y Ibyanditswe kandi bakabazanya ibibazo bishingiye kuri Bibiliya. Inyamaswa iyo ari yo yose cyangwa plant.\"\n",
+ "Prediction (lug to lug): \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\", True label: \"Ibyo bintu nizo ntwaro zatumye mbasha kumara iyi myaka atarahukana na rimwe. Sogokuru azaba yarasomye Bibiliya inshuro icumi niyongera kuyisoma. Nagiye mu Bushinwa hashize imyaka ibiri. Hari n'ubwo bakoresha indimu mu kwirinda indwara. Kubera umubano we asa nkaho yagize uruhare. Yemerwa kandi agashishikarizwa gukomeza gukora.\"\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "eed7e7c2982d452c998cb0e8b6a95af7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "training_args = transformers.Seq2SeqTrainingArguments(\n",
+ " **config[\"training_args\"],\n",
+ " report_to= [\n",
+ " platform for platform, use in [(\"wandb\", use_wandb), (\"mlflow\", use_mlflow)] if use]\n",
+ ")\n",
+ "\n",
+ "trainer = transformers.Seq2SeqTrainer(\n",
+ " args=training_args,\n",
+ " model=model,\n",
+ " train_dataset=train_data,\n",
+ " eval_dataset=val_data,\n",
+ " data_collator=data_collator,\n",
+ " compute_metrics=compute_metrics,\n",
+ " processing_class=processor,\n",
+ ")\n",
+ "\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%debug"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Log the config settings for reference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if use_mlflow:\n",
+ " mlflow.log_params(config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Save the full model (not just the adapter weights)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "PoicKYTCrxew"
+ },
+ "outputs": [],
+ "source": [
+ "processor.push_to_hub(config['training_args']['hub_model_id'])\n",
+ "model.push_to_hub(config['training_args']['hub_model_id'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Try running the model on the first test example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "example = next(iter(valid_ds))\n",
+ "input_features = processor(example[\"source\"], sampling_rate=16000, return_tensors=\"pt\").input_features\n",
+ "with torch.no_grad():\n",
+ " predicted_ids = model.generate(input_features.to(\"cuda\"))[0]\n",
+ "transcription = processor.decode(predicted_ids)\n",
+ "print(transcription)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python3 (main venv)",
+ "language": "python",
+ "name": "main"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/preprocessing.py b/preprocessing.py
index d4b2cb8..eb34ebb 100644
--- a/preprocessing.py
+++ b/preprocessing.py
@@ -197,7 +197,7 @@ def augment_audio_speed(r, src_or_tgt, p=0.5, low=0.95, high=1.15):
r[src_or_tgt] = x_with_speed_change
return r
-
+
class NoiseAugmenter:
"""Class to handle noise augmentation with lazy loading of noise datasets."""
@@ -282,7 +282,7 @@ def augment_audio_noise(r,
x = r[src_or_tgt]
if not isinstance(x, np.ndarray):
x = np.array(x)
-
+
# Do nothing for empty inputs
if not len(x):
return r
@@ -296,7 +296,15 @@ def augment_audio_noise(r,
coverage = np.random.uniform(min_coverage, max_coverage)
num_samples_to_affect = int(len(x) * coverage)
start_index = np.random.randint(0, len(x) - num_samples_to_affect)
-
+
+ # Guard: if num_samples_to_affect is 0, nothing to do
+ if num_samples_to_affect == 0:
+ print(
+ f"[augment_audio_noise] Skipping: num_samples_to_affect=0 "
+ f"[augment_audio_noise] (len(x)={len(x)}, coverage={coverage:.2f})"
+ )
+ return r
+
if noise_audio_repo is None:
# Use synthetic white noise
noise = np.random.uniform(-amplitude, amplitude, size=num_samples_to_affect)
@@ -304,21 +312,29 @@ def augment_audio_noise(r,
# Get the singleton instance and load dataset if needed
noise_augmenter = NoiseAugmenter()
noise_dataset = noise_augmenter.get_noise_dataset(noise_audio_repo)
-
+
# Randomly select a noise sample
noise_idx = np.random.randint(0, noise_dataset.num_rows)
noise_sample = np.array(noise_dataset[noise_idx]['audio']['array'])
-
+
# If noise sample is too short, repeat it
if len(noise_sample) < num_samples_to_affect:
repeats = int(np.ceil(num_samples_to_affect / len(noise_sample)))
noise_sample = np.tile(noise_sample, repeats)
-
+
# If noise sample is too long, take a random segment
if len(noise_sample) > num_samples_to_affect:
noise_start = np.random.randint(0, len(noise_sample) - num_samples_to_affect)
noise_sample = noise_sample[noise_start:noise_start + num_samples_to_affect]
-
+
+ # If noise sample is empty, return the original audio
+ if len(noise_sample) == 0:
+ print(
+ f"[augment_audio_noise] Skipping empty noise sample"
+ f"[augment_audio_noise] len(x)={len(x)}, coverage={coverage:.2f}), num_samples_to_affect={num_samples_to_affect}"
+ )
+ return r
+
# Normalize noise amplitude
noise_max = np.amax(np.abs(noise_sample))
if noise_max > 0: # Avoid division by zero
@@ -329,7 +345,7 @@ def augment_audio_noise(r,
# Apply noise to the chosen segment
x_with_noise = np.copy(x) # Make a copy of x to prevent altering the original
x_with_noise[start_index:start_index + num_samples_to_affect] += noise
-
+
r[src_or_tgt] = x_with_noise
return r
@@ -381,7 +397,7 @@ def clean_and_remove_punctuation(
r[src_or_tgt] = ''.join([c for c in r[src_or_tgt] if c not in punct])
return r
-
+
@single_batch_entry
def lower_case(r, src_or_tgt):
r[src_or_tgt] = r[src_or_tgt].lower()
@@ -402,6 +418,5 @@ def set_sample_rate(r, src_or_tgt, rate, p=1.0):
return r
# TODO: Check that the order of preprocessing operations makes sense. For
-# example, don't call match_target_sentence_format_to_source after
+# example, don't call match_target_sentence_format_to_source after
# prefix_dataset_tag (because then the tag is part of the text)
-
\ No newline at end of file