Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct whisper_params {
int32_t max_len = 0;
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
int32_t max_decoders = whisper_context_default_params().max_decoders;
int32_t audio_ctx = 0;

float word_thold = 0.01f;
Expand Down Expand Up @@ -131,6 +132,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-md" || arg == "--max-decoders") { params.max_decoders = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
Expand Down Expand Up @@ -198,6 +200,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -md N, --max-decoders N [%-7d] Max decoders, used to set the text context cache factor\n", params.max_decoders);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
Expand Down Expand Up @@ -981,6 +984,7 @@ int main(int argc, char ** argv) {

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
cparams.max_decoders = params.max_decoders;

if (!params.dtw.empty()) {
cparams.dtw_token_timestamps = true;
Expand Down
2 changes: 2 additions & 0 deletions include/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ extern "C" {
struct whisper_aheads dtw_aheads;

size_t dtw_mem_size; // TODO: remove

int max_decoders; // to be used to setup text context factor
};

typedef struct whisper_token_data {
Expand Down
12 changes: 8 additions & 4 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ static bool whisper_kv_cache_find_slot(
}

if (n_tested >= n_ctx) {
//WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens. n_tested=%d n_ctx=%d cache.head=%d\n", __func__, n_tokens, n_tested, n_ctx, cache.head);
return false;
}
}
Expand Down Expand Up @@ -3408,9 +3408,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
}

// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
// in theory, there can be a case where this is not enough, but in practice it should always be enough
const int factor = 3;
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx (default value)
// Note: there are cases where 3 is not enough specially when increasing beamsize
const int factor = ctx->params.max_decoders;

WHISPER_LOG_DEBUG("%s: init self-attn cache: n_ctx: %d factor: %d\n", __func__, factor*ctx->model.hparams.n_text_ctx, factor);

if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
ctx->model.hparams.n_text_state,
Expand Down Expand Up @@ -3635,6 +3637,7 @@ struct whisper_context_params whisper_context_default_params() {
/*.heads =*/ NULL,
},
/*.dtw_mem_size =*/ 1024*1024*128,
/* max_decoders =*/ 3
};
return result;
}
Expand Down Expand Up @@ -3732,6 +3735,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
WHISPER_LOG_INFO("%s: max-decoders = %d\n", __func__, params.max_decoders);

whisper_context * ctx = new whisper_context;
ctx->params = params;
Expand Down
Loading