diff --git a/klaam/utils/utils.py b/klaam/utils/utils.py index 526e62c..c172160 100644 --- a/klaam/utils/utils.py +++ b/klaam/utils/utils.py @@ -1,5 +1,6 @@ import librosa import torch +import numpy as np from klaam.external.FastSpeech2.buckwalter import bw2ar @@ -8,7 +9,15 @@ def load_file_to_data(file, srate=16_000): batch = {} - speech, sampling_rate = librosa.load(file, sr=srate) + + if isinstance(file, str): # If it's a file path + speech, sampling_rate = librosa.load(file, sr=srate) + elif isinstance(file, np.ndarray): # If it's a NumPy array + speech = file + sampling_rate = srate + else: + raise TypeError(f"Unsupported input type: {type(file)}. Expected str (file path) or np.ndarray.") + batch["speech"] = speech batch["sampling_rate"] = sampling_rate return batch