Commit 3be0c57 (parent 2ef717b)

whisper : fix KV cache allocation

1 file changed: src/whisper.cpp (+30 lines, -7 lines)
@@ -822,6 +822,9 @@ struct whisper_state {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // number of decoders for which we have constructed the KV cache
+    int32_t kv_self_n_dec = 0;
+
     // unified self-attention KV cache for all decoders
     whisper_kv_cache kv_self;
 
@@ -3408,14 +3411,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
     }
 
-    // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
-    // in theory, there can be a case where this is not enough, but in practice it should always be enough
-    const int factor = 3;
-
+    // at this point, we don't know yet how many decoders will be used
+    // later during decoding, if more decoders are used, we will recreate the KV cache respectively
+    state->kv_self_n_dec = 1;
     if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
-                GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
         WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
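
(For context on the sizes involved: the released Whisper models use n_text_ctx = 448, and GGML_PAD rounds that up to a multiple of 256, i.e. 512 KV cells per decoder. The old code therefore always reserved 512*3 = 1536 cells at init time; the new code reserves only 512 and grows the cache on demand in the hunk below. The following is a minimal sketch of that arithmetic only; the pad helper mirrors ggml's round-up behaviour for power-of-two alignments but is written out here rather than taken from ggml.h.)

#include <cstdio>

// round x up to a multiple of n (same result as ggml's GGML_PAD for power-of-two n)
static int pad_up(int x, int n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int n_text_ctx = 448;                    // text context of the Whisper models
    const int padded     = pad_up(n_text_ctx, 256);

    printf("padded ctx          : %d cells\n", padded);      // 512
    printf("old init size (3x)  : %d cells\n", padded * 3);  // 1536
    printf("new init size (1 d.): %d cells\n", padded);      // 512
    return 0;
}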
@@ -5780,13 +5782,34 @@ int whisper_full_with_state(
         }
         WHISPER_LOG_DEBUG("\n\n");
 
+        // recreate the KV cache if the number of decoders has changed
+        if (state->kv_self_n_dec < n_decoders_cur) {
+            WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+            whisper_kv_cache_free(state->kv_self);
+
+            // overallocate to workaround KV cache fragmentation issues
+            const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+            if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                        ctx->model.hparams.n_text_state,
+                        ctx->model.hparams.n_text_layer,
+                        GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+                whisper_free_state(state);
+                return -7;
+            }
+
+            state->kv_self_n_dec = n_decoders_cur;
+        }
+
         whisper_kv_cache_clear(state->kv_self);
 
         whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
         if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
             WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-            return -7;
+            return -8;
         }
 
         {
@@ -6086,7 +6109,7 @@ int whisper_full_with_state(
 
         if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
             WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-            return -8;
+            return -9;
         }
 
         const int64_t t_start_sample_us = ggml_time_us();
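
Taken together, the commit replaces an unconditional 3x overallocation with lazy growth: the self-attention cache starts sized for a single decoder, and whisper_full_with_state recreates it the first time a run requests more decoders (e.g. beam search with several beams), using factor = n_decoders_cur + 2, per the diff's comment, to leave headroom against fragmentation of the shared cache. The sketch below is a simplified, self-contained illustration of that grow-on-demand pattern; the cache struct, init/free helpers, and size constant are stand-ins for illustration, not the whisper.cpp API.

#include <cstdio>
#include <vector>

// Simplified stand-in for the unified self-attention KV cache.
struct kv_cache {
    int n_cells = 0;           // capacity in token cells
    std::vector<float> buf;    // placeholder for the actual K/V tensors
};

static bool kv_cache_init(kv_cache & cache, int n_cells) {
    cache.n_cells = n_cells;
    cache.buf.assign((size_t) n_cells, 0.0f); // the real code allocates ggml tensors here
    return true;
}

static void kv_cache_free(kv_cache & cache) {
    cache.n_cells = 0;
    cache.buf.clear();
}

struct decode_state {
    kv_cache kv_self;
    int      kv_self_n_dec = 0; // number of decoders the cache was built for
};

constexpr int PADDED_CTX = 512; // GGML_PAD(n_text_ctx = 448, 256)

// Grow-on-demand: recreate the cache only when more decoders are requested
// than it was built for, mirroring the logic added in this commit.
static bool ensure_kv_cache(decode_state & state, int n_decoders_cur) {
    if (state.kv_self_n_dec >= n_decoders_cur) {
        return true; // current cache is already large enough
    }

    kv_cache_free(state.kv_self);

    // overallocate to leave headroom for the cache shared by all decoders
    const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;

    if (!kv_cache_init(state.kv_self, PADDED_CTX * factor)) {
        return false;
    }

    state.kv_self_n_dec = n_decoders_cur;
    return true;
}

int main() {
    decode_state state;

    kv_cache_init(state.kv_self, PADDED_CTX); // init-time allocation for one decoder
    state.kv_self_n_dec = 1;

    ensure_kv_cache(state, 1); // greedy decoding: no reallocation
    printf("1 decoder : %d cells\n", state.kv_self.n_cells); // 512

    ensure_kv_cache(state, 5); // e.g. 5 beams: recreate with factor 5 + 2 = 7
    printf("5 decoders: %d cells\n", state.kv_self.n_cells); // 3584

    ensure_kv_cache(state, 3); // fewer decoders than before: keep the existing cache
    printf("3 decoders: %d cells\n", state.kv_self.n_cells); // still 3584
    return 0;
}

Note that, as in the diff's check (kv_self_n_dec < n_decoders_cur), the cache only ever grows within the lifetime of a state, so alternating between greedy and beam-search runs does not reallocate it back and forth.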
