@@ -822,6 +822,9 @@ struct whisper_state {
822
822
int32_t n_fail_p = 0 ; // number of logprob threshold failures
823
823
int32_t n_fail_h = 0 ; // number of entropy threshold failures
824
824
825
+ // number of decoders for which we have constructed the KV cache
826
+ int32_t kv_self_n_dec = 0 ;
827
+
825
828
// unified self-attention KV cache for all decoders
826
829
whisper_kv_cache kv_self;
827
830
@@ -3408,14 +3411,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3408
3411
whisper_mel_init (state->mel , state->backends [0 ], n_len, n_len, n_mel);
3409
3412
}
3410
3413
3411
- // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3412
- // in theory, there can be a case where this is not enough, but in practice it should always be enough
3413
- const int factor = 3 ;
3414
-
3414
+ // at this point, we don't know yet how many decoders will be used
3415
+ // later during decoding, if more decoders are used, we will recreate the KV cache accordingly
3416
+ state->kv_self_n_dec = 1 ;
3415
3417
if (!whisper_kv_cache_init (state->kv_self , state->backends [0 ], ctx->itype ,
3416
3418
ctx->model .hparams .n_text_state ,
3417
3419
ctx->model .hparams .n_text_layer ,
3418
- GGML_PAD (ctx->model .hparams .n_text_ctx , 256 )*factor )) {
3420
+ GGML_PAD (ctx->model .hparams .n_text_ctx , 256 ))) {
3419
3421
WHISPER_LOG_ERROR (" %s: whisper_kv_cache_init() failed for self-attention cache\n " , __func__);
3420
3422
whisper_free_state (state);
3421
3423
return nullptr ;
@@ -5780,13 +5782,34 @@ int whisper_full_with_state(
5780
5782
}
5781
5783
WHISPER_LOG_DEBUG (" \n\n " );
5782
5784
5785
+ // recreate the KV cache if the number of decoders has changed
5786
+ if (state->kv_self_n_dec < n_decoders_cur) {
5787
+ WHISPER_LOG_DEBUG (" %s: recreating KV cache: n_decoders_cur = %d\n " , __func__, n_decoders_cur);
5788
+
5789
+ whisper_kv_cache_free (state->kv_self );
5790
+
5791
+ // overallocate to work around KV cache fragmentation issues
5792
+ const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1 ;
5793
+
5794
+ if (!whisper_kv_cache_init (state->kv_self , state->backends [0 ], ctx->itype ,
5795
+ ctx->model .hparams .n_text_state ,
5796
+ ctx->model .hparams .n_text_layer ,
5797
+ GGML_PAD (ctx->model .hparams .n_text_ctx , 256 )*factor)) {
5798
+ WHISPER_LOG_ERROR (" %s: whisper_kv_cache_init() failed for self-attention cache\n " , __func__);
5799
+ whisper_free_state (state);
5800
+ return -7 ;
5801
+ }
5802
+
5803
+ state->kv_self_n_dec = n_decoders_cur;
5804
+ }
5805
+
5783
5806
whisper_kv_cache_clear (state->kv_self );
5784
5807
5785
5808
whisper_batch_prep_legacy (state->batch , prompt.data (), prompt.size (), 0 , 0 );
5786
5809
5787
5810
if (!whisper_decode_internal (*ctx, *state, state->batch , params.n_threads , false , params.abort_callback , params.abort_callback_user_data )) {
5788
5811
WHISPER_LOG_ERROR (" %s: failed to decode\n " , __func__);
5789
- return -7 ;
5812
+ return -8 ;
5790
5813
}
5791
5814
5792
5815
{
@@ -6086,7 +6109,7 @@ int whisper_full_with_state(
6086
6109
6087
6110
if (!whisper_decode_internal (*ctx, *state, state->batch , params.n_threads , false , params.abort_callback , params.abort_callback_user_data )) {
6088
6111
WHISPER_LOG_ERROR (" %s: failed to decode\n " , __func__);
6089
- return -8 ;
6112
+ return -9 ;
6090
6113
}
6091
6114
6092
6115
const int64_t t_start_sample_us = ggml_time_us ();
0 commit comments