
Commit 2538c5f

Merge pull request #1 from charnesp/whisperx: Whisperx integration

2 parents: f2f39fe + f4981dd

File tree: 14 files changed (+351, -71 lines)

.devcontainer/devcontainer.json

Lines changed: 44 additions & 0 deletions
```diff
@@ -0,0 +1,44 @@
+// For format details, see https://aka.ms/devcontainer.json. For config options, see the
+// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-docker-compose
+{
+    "name": "Existing Docker Compose (Extend)",
+
+    // Update the 'dockerComposeFile' list if you have more compose files or use different names.
+    // The .devcontainer/docker-compose.yml file contains any overrides you need/want to make.
+    "dockerComposeFile": [
+        "../docker-compose.yml",
+        "docker-compose.yml"
+    ],
+
+    // The 'service' property is the name of the service for the container that VS Code should
+    // use. Update this value and .devcontainer/docker-compose.yml to the real service name.
+    "service": "whisper-asr-webservice",
+
+    // The optional 'workspaceFolder' property is the path VS Code should open by default when
+    // connected. This is typically a file mount in .devcontainer/docker-compose.yml
+    "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
+
+    // "overrideCommand": "/bin/sh -c 'while sleep 1000; do :; done'"
+    "overrideCommand": true
+
+    // Features to add to the dev container. More info: https://containers.dev/features.
+    // "features": {},
+
+    // Use 'forwardPorts' to make a list of ports inside the container available locally.
+    // "forwardPorts": [],
+
+    // Uncomment the next line if you want to start specific services in your Docker Compose config.
+    // "runServices": [],
+
+    // Uncomment the next line if you want to keep your containers running after VS Code shuts down.
+    // "shutdownAction": "none",
+
+    // Uncomment the next line to run commands after the container is created.
+    // "postCreateCommand": "cat /etc/os-release",
+
+    // Configure tool-specific properties.
+    // "customizations": {},
+
+    // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
+    // "remoteUser": "devcontainer"
+}
```

.devcontainer/docker-compose.yml

Lines changed: 30 additions & 0 deletions
```diff
@@ -0,0 +1,30 @@
+version: '3.4'
+services:
+  # Update this to the name of the service you want to work with in your docker-compose.yml file
+  whisper-asr-webservice:
+    # Uncomment if you want to override the service's Dockerfile to one in the .devcontainer
+    # folder. Note that the path of the Dockerfile and context is relative to the *primary*
+    # docker-compose.yml file (the first in the devcontainer.json "dockerComposeFile"
+    # array). The sample below assumes your primary file is in the root of your project.
+    #
+    # build:
+    #   context: .
+    #   dockerfile: .devcontainer/Dockerfile
+    env_file: .devcontainer/dev.env
+    environment:
+      ASR_ENGINE: ${ASR_ENGINE}
+      HF_TOKEN: ${HF_TOKEN}
+
+    volumes:
+      # Update this to wherever you want VS Code to mount the folder of your project
+      - ..:/workspaces:cached
+
+    # Uncomment the next four lines if you will use a ptrace-based debugger like C++, Go, and Rust.
+    # cap_add:
+    #   - SYS_PTRACE
+    # security_opt:
+    #   - seccomp:unconfined
+
+    # Overrides default command so things don't shut down after the process ends.
+    command: sleep infinity
+
```

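The ASR_ENGINE and HF_TOKEN variables forwarded here come from the gitignored .devcontainer/dev.env file and are read by the application's CONFIG object (CONFIG.HF_TOKEN is used by the new WhisperX engine later in this commit). A minimal sketch of how such a config could pick them up, assuming plain environment-variable lookups (the actual app/config.py is not part of this diff):

```python
import os


class CONFIG:
    # Engine selection, e.g. "whisperx"; the default shown here is an assumption.
    ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
    # Hugging Face token that gates the pyannote diarization pipeline.
    HF_TOKEN = os.getenv("HF_TOKEN", "")
```
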
.gitignore

Lines changed: 3 additions & 1 deletion
```diff
@@ -41,4 +41,6 @@ pip-wheel-metadata
 
 poetry/core/*
 
-public
+public
+
+.devcontainer/dev.env
```
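
The dev.env file itself stays out of version control thanks to the new entry above; it is expected to define the two variables wired through docker-compose, e.g. `ASR_ENGINE=whisperx` and `HF_TOKEN=<your Hugging Face token>` (illustrative values; each developer supplies their own file).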

Dockerfile

Lines changed: 13 additions & 0 deletions
```diff
@@ -8,6 +8,8 @@ RUN export DEBIAN_FRONTEND=noninteractive \
     pkg-config \
     yasm \
     ca-certificates \
+    gcc \
+    python3-dev \
     && rm -rf /var/lib/apt/lists/*
 
 RUN git clone https://github.com/FFmpeg/FFmpeg.git --depth 1 --branch n6.1.1 --single-branch /FFmpeg-6.1.1
@@ -42,6 +44,12 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
 
 FROM python:3.10-bookworm
 
+RUN export DEBIAN_FRONTEND=noninteractive \
+    && apt-get -qq update \
+    && apt-get -qq install --no-install-recommends \
+    libsndfile1 \
+    && rm -rf /var/lib/apt/lists/*
+
 ENV POETRY_VENV=/app/.venv
 
 RUN python3 -m venv $POETRY_VENV \
@@ -61,6 +69,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
 RUN poetry config virtualenvs.in-project true
 RUN poetry install
 
+RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+    && cd whisperX \
+    && $POETRY_VENV/bin/pip install -e .
+
 EXPOSE 9000
 
 ENTRYPOINT ["whisper-asr-webservice"]
```

Dockerfile.gpu

Lines changed: 12 additions & 0 deletions
```diff
@@ -43,6 +43,13 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
 FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
 
 ENV PYTHON_VERSION=3.10
+
+RUN export DEBIAN_FRONTEND=noninteractive \
+    && apt-get -qq update \
+    && apt-get -qq install --no-install-recommends \
+    libsndfile1 \
+    && rm -rf /var/lib/apt/lists/*
+
 ENV POETRY_VENV=/app/.venv
 
 RUN export DEBIAN_FRONTEND=noninteractive \
@@ -79,6 +86,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
 RUN poetry install
 RUN $POETRY_VENV/bin/pip install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch
 
+RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+    && cd whisperX \
+    && $POETRY_VENV/bin/pip install -e .
+
 EXPOSE 9000
 
 CMD whisper-asr-webservice
```

README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,6 +13,7 @@ Current release (v1.7.1) supports following whisper models:
 
 - [openai/whisper](https://github.com/openai/whisper)@[v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
 - [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.1.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/v1.1.0)
+- [whisperX](https://github.com/m-bain/whisperX)@[v3.1.1](https://github.com/m-bain/whisperX/releases/tag/v3.1.1)
 
 ## Quick Usage
 
```
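
With whisperX available as an engine, a transcription request with diarization against a running instance might look like the sketch below. The /asr endpoint, the audio_file form field, and the output parameter follow the project's existing API; the diarize and max_speakers query parameters are assumptions inferred from the options dict handled by the new engine.

```python
import requests

# Hypothetical call against a local container (port 9000 is exposed
# by the Dockerfiles in this commit).
with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:9000/asr",
        params={"output": "srt", "diarize": "true", "max_speakers": 2},  # diarize passthrough assumed
        files={"audio_file": f},
    )
print(resp.text)  # SRT subtitles, with speaker labels when diarization ran
```
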
app/asr_models/asr_model.py

Lines changed: 18 additions & 9 deletions
```diff
@@ -13,7 +13,10 @@ class ASRModel(ABC):
     """
     Abstract base class for ASR (Automatic Speech Recognition) models.
     """
+
     model = None
+    diarize_model = None  # used for WhisperX
+    x_models = dict()  # used for WhisperX
     model_lock = Lock()
     last_activity_time = time.time()
 
@@ -28,14 +31,17 @@ def load_model(self):
         pass
 
     @abstractmethod
-    def transcribe(self,
-                   audio,
-                   task: Union[str, None],
-                   language: Union[str, None],
-                   initial_prompt: Union[str, None],
-                   vad_filter: Union[bool, None],
-                   word_timestamps: Union[bool, None]
-                   ):
+    def transcribe(
+        self,
+        audio,
+        task: Union[str, None],
+        language: Union[str, None],
+        initial_prompt: Union[str, None],
+        vad_filter: Union[bool, None],
+        word_timestamps: Union[bool, None],
+        options: Union[dict, None],
+        output,
+    ):
         """
         Perform transcription on the given audio file.
         """
@@ -52,7 +58,8 @@ def monitor_idleness(self):
         """
         Monitors the idleness of the ASR model and releases the model if it has been idle for too long.
         """
-        if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
+        if CONFIG.MODEL_IDLE_TIMEOUT <= 0:
+            return
         while True:
             time.sleep(15)
             if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
@@ -68,4 +75,6 @@ def release_model(self):
             torch.cuda.empty_cache()
             gc.collect()
             self.model = None
+            self.diarize_model = None
+            self.x_models = dict()
             print("Model unloaded due to timeout")
```

app/asr_models/faster_whisper_engine.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@ def transcribe(
         initial_prompt: Union[str, None],
         vad_filter: Union[bool, None],
         word_timestamps: Union[bool, None],
+        options: Union[dict, None],
         output,
     ):
         self.last_activity_time = time.time()
```
Lines changed: 111 additions & 0 deletions
```diff
@@ -0,0 +1,111 @@
+from typing import BinaryIO, Union
+from io import StringIO
+import whisperx
+import whisper
+from whisperx.utils import SubtitlesWriter, ResultWriter
+
+from app.asr_models.asr_model import ASRModel
+from app.config import CONFIG
+from app.utils import WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
+
+
+class WhisperXASR(ASRModel):
+    def __init__(self):
+        self.x_models = dict()
+
+    def load_model(self):
+
+        asr_options = {"without_timestamps": False}
+        self.model = whisperx.load_model(
+            CONFIG.MODEL_NAME, device=CONFIG.DEVICE, compute_type="float32", asr_options=asr_options
+        )
+
+        if CONFIG.HF_TOKEN != "":
+            self.diarize_model = whisperx.DiarizationPipeline(use_auth_token=CONFIG.HF_TOKEN, device=CONFIG.DEVICE)
+
+    def transcribe(
+        self,
+        audio,
+        task: Union[str, None],
+        language: Union[str, None],
+        initial_prompt: Union[str, None],
+        vad_filter: Union[bool, None],
+        word_timestamps: Union[bool, None],
+        options: Union[dict, None],
+        output,
+    ):
+        options_dict = {"task": task}
+        if language:
+            options_dict["language"] = language
+        if initial_prompt:
+            options_dict["initial_prompt"] = initial_prompt
+        with self.model_lock:
+            if self.model is None:
+                self.load_model()
+            result = self.model.transcribe(audio, **options_dict)
+
+        # Load the required alignment model and cache it.
+        # If we transcribe audio in many different languages, this may lead to OOM problems.
+        if result["language"] in self.x_models:
+            model_x, metadata = self.x_models[result["language"]]
+        else:
+            self.x_models[result["language"]] = whisperx.load_align_model(
+                language_code=result["language"], device=CONFIG.DEVICE
+            )
+            model_x, metadata = self.x_models[result["language"]]
+
+        # Align whisper output
+        result = whisperx.align(
+            result["segments"], model_x, metadata, audio, CONFIG.DEVICE, return_char_alignments=False
+        )
+
+        if options.get("diarize", False):
+            if CONFIG.HF_TOKEN == "":
+                print("Warning! HF_TOKEN is not set. Diarization may not work as expected.")
+            min_speakers = options.get("min_speakers", None)
+            max_speakers = options.get("max_speakers", None)
+            # add min/max number of speakers if known
+            diarize_segments = self.diarize_model(audio, min_speakers, max_speakers)
+            result = whisperx.assign_word_speakers(diarize_segments, result)
+
+        output_file = StringIO()
+        self.write_result(result, output_file, output)
+        output_file.seek(0)
+
+        return output_file
+
+    def language_detection(self, audio):
+        # load audio and pad/trim it to fit 30 seconds
+        audio = whisper.pad_or_trim(audio)
+
+        # make log-Mel spectrogram and move to the same device as the model
+        mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
+
+        # detect the spoken language
+        with self.model_lock:
+            if self.model is None:
+                self.load_model()
+            _, probs = self.model.detect_language(mel)
+        detected_lang_code = max(probs, key=probs.get)
+
+        return detected_lang_code
+
+    def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
+        if output == "srt":
+            if CONFIG.HF_TOKEN != "":
+                WriteSRT(SubtitlesWriter).write_result(result, file=file, options={})
+            else:
+                WriteSRT(ResultWriter).write_result(result, file=file, options={})
+        elif output == "vtt":
+            if CONFIG.HF_TOKEN != "":
+                WriteVTT(SubtitlesWriter).write_result(result, file=file, options={})
+            else:
+                WriteVTT(ResultWriter).write_result(result, file=file, options={})
+        elif output == "tsv":
+            WriteTSV(ResultWriter).write_result(result, file=file, options={})
+        elif output == "json":
+            WriteJSON(ResultWriter).write_result(result, file=file, options={})
+        elif output == "txt":
+            WriteTXT(ResultWriter).write_result(result, file=file, options={})
+        else:
+            return 'Please select an output method!'
```

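Taken end to end, a direct call into this engine looks roughly like the sketch below; whisperx.load_audio for input loading and the engine's module path are assumptions (the diff does not show the new file's name), and the service would normally reach transcribe through its HTTP layer instead.

```python
import whisperx

from app.asr_models.whisperx_engine import WhisperXASR  # module path assumed

asr = WhisperXASR()
audio = whisperx.load_audio("sample.wav")  # 16 kHz mono float32 array

srt_file = asr.transcribe(
    audio,
    task="transcribe",
    language="en",
    initial_prompt=None,
    vad_filter=None,
    word_timestamps=False,
    options={"diarize": True, "min_speakers": 1, "max_speakers": 2},
    output="srt",
)
print(srt_file.read())  # StringIO, rewound to position 0 by the engine
```
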
app/asr_models/openai_whisper_engine.py

Lines changed: 15 additions & 18 deletions
```diff
@@ -16,32 +16,28 @@ class OpenAIWhisperASR(ASRModel):
     def load_model(self):
 
         if torch.cuda.is_available():
-            self.model = whisper.load_model(
-                name=CONFIG.MODEL_NAME,
-                download_root=CONFIG.MODEL_PATH
-            ).cuda()
+            self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH).cuda()
         else:
-            self.model = whisper.load_model(
-                name=CONFIG.MODEL_NAME,
-                download_root=CONFIG.MODEL_PATH
-            )
+            self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH)
 
         Thread(target=self.monitor_idleness, daemon=True).start()
 
     def transcribe(
-            self,
-            audio,
-            task: Union[str, None],
-            language: Union[str, None],
-            initial_prompt: Union[str, None],
-            vad_filter: Union[bool, None],
-            word_timestamps: Union[bool, None],
-            output,
+        self,
+        audio,
+        task: Union[str, None],
+        language: Union[str, None],
+        initial_prompt: Union[str, None],
+        vad_filter: Union[bool, None],
+        word_timestamps: Union[bool, None],
+        options: Union[dict, None],
+        output,
     ):
         self.last_activity_time = time.time()
 
         with self.model_lock:
-            if self.model is None: self.load_model()
+            if self.model is None:
+                self.load_model()
 
         options_dict = {"task": task}
         if language:
@@ -64,7 +60,8 @@ def language_detection(self, audio):
         self.last_activity_time = time.time()
 
         with self.model_lock:
-            if self.model is None: self.load_model()
+            if self.model is None:
+                self.load_model()
 
         # load audio and pad/trim it to fit 30 seconds
         audio = whisper.pad_or_trim(audio)
```
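
All three engines now share one transcribe signature, so the webservice can pick an implementation from CONFIG.ASR_ENGINE. A hedged sketch of such a dispatch (the factory itself, the module paths, and the class names other than WhisperXASR are assumptions; the actual wiring is outside this diff):

```python
def create_asr_engine(name: str):
    # Hypothetical factory keyed on the ASR_ENGINE environment variable.
    if name == "openai_whisper":
        from app.asr_models.openai_whisper_engine import OpenAIWhisperASR
        return OpenAIWhisperASR()
    if name == "faster_whisper":
        from app.asr_models.faster_whisper_engine import FasterWhisperASR
        return FasterWhisperASR()
    if name == "whisperx":
        from app.asr_models.whisperx_engine import WhisperXASR  # path assumed
        return WhisperXASR()
    raise ValueError(f"Unknown ASR engine: {name}")
```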
