Commit 30abb70

author: eleanorTurintech
committed: Performance improvements
1 parent 517a43e commit 30abb70

2 files changed: +63 -10 lines changed

whisper/__init__.py

Lines changed: 21 additions & 9 deletions
@@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")

+    def compute_sha256(file_path: str) -> str:
+        sha256 = hashlib.sha256()
+        with open(file_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                sha256.update(chunk)
+        return sha256.hexdigest()
+
     if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes if in_memory else download_target
+        if compute_sha256(download_target) == expected_sha256:
+            if in_memory:
+                with open(download_target, "rb") as f:
+                    return f.read()
+            else:
+                return download_target
         else:
             warnings.warn(
                 f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
@@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
             output.write(buffer)
             loop.update(len(buffer))

-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+    if compute_sha256(download_target) != expected_sha256:
         raise RuntimeError(
             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
         )

-    return model_bytes if in_memory else download_target
+    if in_memory:
+        with open(download_target, "rb") as f:
+            return f.read()
+    else:
+        return download_target


 def available_models() -> List[str]:
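
The checksum of an existing or freshly downloaded file is now computed in 8 KiB chunks, and the file bytes are read back only when in_memory=True, so the multi-gigabyte checkpoint is never held in memory just to be verified. A standalone sketch of the same streaming-hash pattern (the helper is lifted to module level here, and the temp-file self-check is illustrative, not part of the commit):

import hashlib
import os
import tempfile


def compute_sha256(file_path: str) -> str:
    # Stream the file through the hash in 8 KiB chunks; peak memory is one
    # chunk rather than the whole file (Whisper checkpoints reach ~3 GB).
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            sha256.update(chunk)
    return sha256.hexdigest()


# Self-check against the one-shot API on a small temporary file.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"hello whisper")
assert compute_sha256(tmp.name) == hashlib.sha256(b"hello whisper").hexdigest()
os.remove(tmp.name)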
@@ -147,7 +159,7 @@ def load_model(
     with (
         io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
     ) as fp:
-        checkpoint = torch.load(fp, map_location=device)
+        checkpoint = torch.load(fp, map_location=device, weights_only=True)
     del checkpoint_file

     dims = ModelDimensions(**checkpoint["dims"])
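
Passing weights_only=True makes torch.load deserialise only tensors and primitive containers instead of arbitrary pickled objects, so a tampered checkpoint cannot execute code at load time (the flag exists since PyTorch 1.13, and newer releases default to it). A minimal sketch of the behaviour, using a hypothetical local file checkpoint.pt:

import torch

# A plain dict of primitives and tensors round-trips fine under weights_only.
torch.save({"dims": {"n_vocab": 51865}, "weights": torch.zeros(2, 2)}, "checkpoint.pt")

# weights_only=True refuses arbitrary pickled objects, so loading an
# untrusted checkpoint cannot run code during unpickling.
checkpoint = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
print(checkpoint["dims"])  # {'n_vocab': 51865}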
@@ -157,4 +169,4 @@ def load_model(
     if alignment_heads is not None:
         model.set_alignment_heads(alignment_heads)

-    return model.to(device)
+    return model.to(device)

whisper/model.py

Lines changed: 42 additions & 1 deletion
@@ -224,6 +224,47 @@ def __init__(
         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

+        # Optimisation: pre-compute and register the mask in CUDA if available
+        if torch.cuda.is_available():
+            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
+
+
+    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+        """
+        Args:
+            tokens: (n_batch, n_token)
+            audio_features: (n_batch, n_audio_ctx, n_audio_state)
+
+        Returns:
+            logits: (n_batch, n_token, n_vocab)
+        """
+        n_batch, n_token = tokens.shape
+        n_audio_ctx, n_audio_state = audio_features.shape[1:]
+
+        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Optimisation: Move audio_features to GPU once here.
+        if torch.cuda.is_available():
+            audio_features = audio_features.cuda()
+
+
+        for block in self.blocks:
+            x = block(x, audio_features)
+
+        x = self.ln(x)
+        logits = x @ self.token_embedding.weight.T
+
+        # Optimisation: Apply the precomputed CUDA mask if available.
+        if torch.cuda.is_available():
+            mask = self.mask_cuda[:n_token, :n_token]
+        else:
+            mask = self.mask[:n_token, :n_token]
+
+        logits = logits + mask
+
+        return logits
+
+
     def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
         """
         x : torch.LongTensor, shape = (batch_size, <= n_ctx)
@@ -342,4 +383,4 @@ def install_hooks(layer: nn.Module):

 detect_language = detect_language_function
 transcribe = transcribe_function
-decode = decode_function
+decode = decode_function
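
The decoder now keeps a CUDA-resident copy of the causal mask alongside the CPU buffer, so the forward pass slices an already-uploaded tensor instead of paying a host-to-device copy each call. The same pattern in isolation, as a sketch around a toy module (MaskedScores is illustrative, not from the commit):

import torch
from torch import nn


class MaskedScores(nn.Module):
    def __init__(self, n_ctx: int):
        super().__init__()
        # Causal mask: -inf above the diagonal, 0 elsewhere.
        mask = torch.empty(n_ctx, n_ctx).fill_(float("-inf")).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
        if torch.cuda.is_available():
            # Device-resident copy registered once, so hot paths never re-upload it.
            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        n_token = scores.shape[-1]
        mask = self.mask_cuda if scores.is_cuda else self.mask
        return scores + mask[:n_token, :n_token]


m = MaskedScores(8)
print(m(torch.zeros(1, 4, 4)))  # upper triangle of each 4x4 slice becomes -inf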
