from typing import TextIO, Union
from io import StringIO
import whisperx
import whisper
from whisperx.utils import SubtitlesWriter, ResultWriter

from app.asr_models.asr_model import ASRModel
from app.config import CONFIG
from app.utils import WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON


class WhisperXASR(ASRModel):
    def __init__(self):
        self.x_models = dict()
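        # self.model and self.model_lock are expected to be provided by the
        # ASRModel base class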

    def load_model(self):
        asr_options = {"without_timestamps": False}
        self.model = whisperx.load_model(
            CONFIG.MODEL_NAME, device=CONFIG.DEVICE, compute_type="float32", asr_options=asr_options
        )

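        # Diarization uses pyannote models via whisperx, which need a Hugging
        # Face token to be downloaded; without a token no pipeline is created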
        if CONFIG.HF_TOKEN != "":
            self.diarize_model = whisperx.DiarizationPipeline(use_auth_token=CONFIG.HF_TOKEN, device=CONFIG.DEVICE)

    def transcribe(
        self,
        audio,
        task: Union[str, None],
        language: Union[str, None],
        initial_prompt: Union[str, None],
        vad_filter: Union[bool, None],
        word_timestamps: Union[bool, None],
        options: Union[dict, None],
        output,
    ):
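        # vad_filter and word_timestamps are accepted for interface parity
        # with the other ASR engines but are not used by the whisperx path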
        options_dict = {"task": task}
        if language:
            options_dict["language"] = language
        if initial_prompt:
            options_dict["initial_prompt"] = initial_prompt
        with self.model_lock:
            if self.model is None:
                self.load_model()
            result = self.model.transcribe(audio, **options_dict)

        # Load the required alignment model and cache it.
        # If we transcribe audio in many different languages, this may lead to OOM problems.
        if result["language"] not in self.x_models:
            self.x_models[result["language"]] = whisperx.load_align_model(
                language_code=result["language"], device=CONFIG.DEVICE
            )
        model_x, metadata = self.x_models[result["language"]]

        # Align whisper output
        result = whisperx.align(
            result["segments"], model_x, metadata, audio, CONFIG.DEVICE, return_char_alignments=False
        )
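        # the aligned result carries word-level start/end timestamps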

        if options and options.get("diarize", False):
            if CONFIG.HF_TOKEN == "":
                print("Warning! HF_TOKEN is not set. Diarization may not work as expected.")
            # add min/max number of speakers if known
            min_speakers = options.get("min_speakers", None)
            max_speakers = options.get("max_speakers", None)
            diarize_segments = self.diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
            result = whisperx.assign_word_speakers(diarize_segments, result)
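            # aligned segments and words now carry "speaker" labels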

        output_file = StringIO()
        self.write_result(result, output_file, output)
        output_file.seek(0)

        return output_file

    def language_detection(self, audio):
        # pad/trim audio to fit the 30 seconds the language detector looks at
        audio = whisper.pad_or_trim(audio)

        # detect the spoken language
        with self.model_lock:
            if self.model is None:
                self.load_model()
            # whisperx.load_model returns a faster-whisper based pipeline, so
            # we use its detect_language, which takes raw audio rather than a
            # mel spectrogram and is assumed to return the language code
            detected_lang_code = self.model.detect_language(audio)

        return detected_lang_code

    def write_result(self, result: dict, file: TextIO, output: Union[str, None]):
        # use whisperx's SubtitlesWriter for srt/vtt when an HF_TOKEN is
        # configured, and the plain ResultWriter otherwise
        subtitle_writer = SubtitlesWriter if CONFIG.HF_TOKEN != "" else ResultWriter
        if output == "srt":
            WriteSRT(subtitle_writer).write_result(result, file=file, options={})
        elif output == "vtt":
            WriteVTT(subtitle_writer).write_result(result, file=file, options={})
        elif output == "tsv":
            WriteTSV(ResultWriter).write_result(result, file=file, options={})
        elif output == "json":
            WriteJSON(ResultWriter).write_result(result, file=file, options={})
        elif output == "txt":
            WriteTXT(ResultWriter).write_result(result, file=file, options={})
        else:
            return "Please select an output method!"
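
# Minimal usage sketch (an illustrative assumption, not part of the service):
# "sample.wav" is a placeholder path, and the input is the 16 kHz mono float32
# waveform returned by whisperx.load_audio; the model loads lazily on first use.
#
#   asr = WhisperXASR()
#   audio = whisperx.load_audio("sample.wav")
#   srt_buffer = asr.transcribe(
#       audio, task="transcribe", language="en", initial_prompt=None,
#       vad_filter=None, word_timestamps=None, options={}, output="srt",
#   )
#   print(srt_buffer.read())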