From 77114b733b92adabc634f53a24943bb15577b7b7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Sep 2025 13:23:43 +0300 Subject: [PATCH] whisper : enable flash attention by default --- examples/addon.node/addon.cpp | 38 ++++++++--------- examples/bench/bench.cpp | 32 ++++++++------- examples/cli/cli.cpp | 6 ++- examples/command/command.cpp | 50 ++++++++++++----------- examples/lsp/lsp.cpp | 36 ++++++++-------- examples/server/server.cpp | 6 ++- examples/stream/stream.cpp | 6 ++- examples/talk-llama/talk-llama.cpp | 8 ++-- examples/wchess/wchess.cmd/wchess.cmd.cpp | 6 ++- src/whisper.cpp | 2 +- 10 files changed, 103 insertions(+), 87 deletions(-) diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 952e44e3ce7..71f65b0423c 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -207,7 +207,7 @@ class ProgressWorker : public Napi::AsyncWorker { auto callback = [progress](Napi::Env env, Napi::Function jsCallback) { jsCallback.Call({Napi::Number::New(env, progress)}); }; - + tsfn.BlockingCall(callback); } } @@ -396,59 +396,59 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { std::string language = whisper_params.Get("language").As(); std::string model = whisper_params.Get("model").As(); std::string input = whisper_params.Get("fname_inp").As(); - + bool use_gpu = true; if (whisper_params.Has("use_gpu") && whisper_params.Get("use_gpu").IsBoolean()) { use_gpu = whisper_params.Get("use_gpu").As(); } - + bool flash_attn = false; if (whisper_params.Has("flash_attn") && whisper_params.Get("flash_attn").IsBoolean()) { flash_attn = whisper_params.Get("flash_attn").As(); } - + bool no_prints = false; if (whisper_params.Has("no_prints") && whisper_params.Get("no_prints").IsBoolean()) { no_prints = whisper_params.Get("no_prints").As(); } - + bool no_timestamps = false; if (whisper_params.Has("no_timestamps") && whisper_params.Get("no_timestamps").IsBoolean()) { no_timestamps = 
whisper_params.Get("no_timestamps").As(); } - + bool detect_language = false; if (whisper_params.Has("detect_language") && whisper_params.Get("detect_language").IsBoolean()) { detect_language = whisper_params.Get("detect_language").As(); } - + int32_t audio_ctx = 0; if (whisper_params.Has("audio_ctx") && whisper_params.Get("audio_ctx").IsNumber()) { audio_ctx = whisper_params.Get("audio_ctx").As(); } - + bool comma_in_time = true; if (whisper_params.Has("comma_in_time") && whisper_params.Get("comma_in_time").IsBoolean()) { comma_in_time = whisper_params.Get("comma_in_time").As(); } - + int32_t max_len = 0; if (whisper_params.Has("max_len") && whisper_params.Get("max_len").IsNumber()) { max_len = whisper_params.Get("max_len").As(); } - + // Add support for max_context int32_t max_context = -1; if (whisper_params.Has("max_context") && whisper_params.Get("max_context").IsNumber()) { max_context = whisper_params.Get("max_context").As(); } - + // support prompt std::string prompt = ""; if (whisper_params.Has("prompt") && whisper_params.Get("prompt").IsString()) { prompt = whisper_params.Get("prompt").As(); } - + // Add support for print_progress bool print_progress = false; if (whisper_params.Has("print_progress") && whisper_params.Get("print_progress").IsBoolean()) { @@ -465,37 +465,37 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { if (whisper_params.Has("vad") && whisper_params.Get("vad").IsBoolean()) { vad = whisper_params.Get("vad").As(); } - + std::string vad_model = ""; if (whisper_params.Has("vad_model") && whisper_params.Get("vad_model").IsString()) { vad_model = whisper_params.Get("vad_model").As(); } - + float vad_threshold = 0.5f; if (whisper_params.Has("vad_threshold") && whisper_params.Get("vad_threshold").IsNumber()) { vad_threshold = whisper_params.Get("vad_threshold").As(); } - + int vad_min_speech_duration_ms = 250; if (whisper_params.Has("vad_min_speech_duration_ms") && whisper_params.Get("vad_min_speech_duration_ms").IsNumber()) { 
vad_min_speech_duration_ms = whisper_params.Get("vad_min_speech_duration_ms").As(); } - + int vad_min_silence_duration_ms = 100; if (whisper_params.Has("vad_min_silence_duration_ms") && whisper_params.Get("vad_min_silence_duration_ms").IsNumber()) { vad_min_silence_duration_ms = whisper_params.Get("vad_min_silence_duration_ms").As(); } - + float vad_max_speech_duration_s = FLT_MAX; if (whisper_params.Has("vad_max_speech_duration_s") && whisper_params.Get("vad_max_speech_duration_s").IsNumber()) { vad_max_speech_duration_s = whisper_params.Get("vad_max_speech_duration_s").As(); } - + int vad_speech_pad_ms = 30; if (whisper_params.Has("vad_speech_pad_ms") && whisper_params.Get("vad_speech_pad_ms").IsNumber()) { vad_speech_pad_ms = whisper_params.Get("vad_speech_pad_ms").As(); } - + float vad_samples_overlap = 0.1f; if (whisper_params.Has("vad_samples_overlap") && whisper_params.Get("vad_samples_overlap").IsNumber()) { vad_samples_overlap = whisper_params.Get("vad_samples_overlap").As(); diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 36d56769289..2d967f2caf4 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -13,7 +13,7 @@ struct whisper_params { std::string model = "models/ggml-base.en.bin"; bool use_gpu = true; - bool flash_attn = false; + bool flash_attn = true; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -26,11 +26,12 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-t" || arg == 
"--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -46,15 +47,16 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what); - fprintf(stderr, " %-7s 0 - whisper\n", ""); - fprintf(stderr, " %-7s 1 - memcpy\n", ""); - fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", ""); - fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? 
"true" : "false"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what); + fprintf(stderr, " %-7s 0 - whisper\n", ""); + fprintf(stderr, " %-7s 1 - memcpy\n", ""); + fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", ""); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, "\n"); } diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index f73ed9ae078..457a1ff35c2 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -75,7 +75,7 @@ struct whisper_params { bool no_timestamps = false; bool log_score = false; bool use_gpu = true; - bool flash_attn = false; + bool flash_attn = true; bool suppress_nst = false; std::string language = "en"; @@ -193,6 +193,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; } else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; } @@ -271,7 +272,8 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, " -dtw MODEL 
--dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 0f87710cefa..ff7c037417f 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -42,7 +42,7 @@ struct whisper_params { bool print_energy = false; bool no_timestamps = true; bool use_gpu = true; - bool flash_attn = false; + bool flash_attn = true; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -66,28 +66,29 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); } - else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); } - else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } - else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = 
std::stoi(argv[++i]); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } - else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } - else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } - else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; } - else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; } - else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; } - else if ( arg == "--grammar") { params.grammar = argv[++i]; } - else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } - else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); } + else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); } + else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } + else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == 
"-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } + else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } + else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; } + else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; } + else if ( arg == "--grammar") { params.grammar = argv[++i]; } + else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } + else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -116,7 +117,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? 
"true" : "false"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); diff --git a/examples/lsp/lsp.cpp b/examples/lsp/lsp.cpp index cf8b75e7a29..cf47f130c95 100644 --- a/examples/lsp/lsp.cpp +++ b/examples/lsp/lsp.cpp @@ -31,7 +31,7 @@ struct whisper_params { bool print_special = false; bool print_energy = false; bool use_gpu = true; - bool flash_attn = false; + bool flash_attn = true; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -62,21 +62,22 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); } - else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); } - else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } - else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } - else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } - else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
- else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); } + else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); } + else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } + else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } + else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -105,7 +106,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & 
para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, "\n"); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 901f65f6c35..fd9b7784108 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -101,7 +101,7 @@ struct whisper_params { bool print_progress = false; bool no_timestamps = false; bool use_gpu = true; - bool flash_attn = false; + bool flash_attn = true; bool suppress_nst = false; bool no_context = false; bool no_language_probabilities = false; @@ -178,7 +178,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? 
"true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false"); // Voice Activity Detection (VAD) parameters fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n"); @@ -236,6 +237,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } else if (arg == "-nc" || arg == "--no-context") { params.no_context = true; } diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 37b23886821..94f9016e075 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -36,7 +36,7 @@ struct whisper_params { bool tinydiarize = false; bool save_audio = false; // save audio to wav file bool use_gpu = true; - bool flash_attn = false; + bool flash_attn = true; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -74,6 +74,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else { fprintf(stderr, "error: unknown argument: %s\n", 
arg.c_str()); @@ -111,7 +112,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention during inference\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention during inference\n", params.flash_attn ? "false" : "true"); fprintf(stderr, "\n"); } diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index 239c56902d4..e98ca64035f 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -66,7 +66,7 @@ struct whisper_params { float top_p = 0.80f; float min_p = 0.01f; float temp = 0.30f; - + float vad_thold = 0.6f; float freq_thold = 100.0f; @@ -76,7 +76,7 @@ struct whisper_params { bool no_timestamps = true; bool verbose_prompt = false; bool use_gpu = true; - bool flash_attn = false; + bool flash_attn = true; std::string person = "Georgi"; std::string bot_name = "LLaMA"; @@ -122,6 +122,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } else if (arg == "-bn" || arg 
== "--bot-name") { params.bot_name = argv[++i]; } else if (arg == "--session") { params.path_session = argv[++i]; } @@ -175,7 +176,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str()); fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str()); fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str()); diff --git a/examples/wchess/wchess.cmd/wchess.cmd.cpp b/examples/wchess/wchess.cmd/wchess.cmd.cpp index 816eb1b3c95..8673d13d052 100644 --- a/examples/wchess/wchess.cmd/wchess.cmd.cpp +++ b/examples/wchess/wchess.cmd/wchess.cmd.cpp @@ -31,7 +31,7 @@ struct whisper_params { bool print_energy = false; bool no_timestamps = true; bool use_gpu = true; - bool flash_attn = false; + bool flash_attn = true; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -60,7 +60,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? 
"true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during decoding\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention during decoding\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention during decoding\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); @@ -92,6 +93,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } diff --git a/src/whisper.cpp b/src/whisper.cpp index efc3192b47c..d99dd7be68c 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -3592,7 +3592,7 @@ int whisper_ctx_init_openvino_encoder( struct whisper_context_params whisper_context_default_params() { struct whisper_context_params result = { /*.use_gpu =*/ true, - /*.flash_attn =*/ false, + /*.flash_attn =*/ true, /*.gpu_device =*/ 0, /*.dtw_token_timestamps =*/ false,