Skip to content

Commit 608b9d3

Browse files
committed
whisper : enable flash attention by default
1 parent b57b9d3 commit 608b9d3

File tree

10 files changed

+103
-87
lines changed

10 files changed

+103
-87
lines changed

examples/addon.node/addon.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class ProgressWorker : public Napi::AsyncWorker {
207207
auto callback = [progress](Napi::Env env, Napi::Function jsCallback) {
208208
jsCallback.Call({Napi::Number::New(env, progress)});
209209
};
210-
210+
211211
tsfn.BlockingCall(callback);
212212
}
213213
}
@@ -396,59 +396,59 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
396396
std::string language = whisper_params.Get("language").As<Napi::String>();
397397
std::string model = whisper_params.Get("model").As<Napi::String>();
398398
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
399-
399+
400400
bool use_gpu = true;
401401
if (whisper_params.Has("use_gpu") && whisper_params.Get("use_gpu").IsBoolean()) {
402402
use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
403403
}
404-
404+
405405
bool flash_attn = false;
406406
if (whisper_params.Has("flash_attn") && whisper_params.Get("flash_attn").IsBoolean()) {
407407
flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
408408
}
409-
409+
410410
bool no_prints = false;
411411
if (whisper_params.Has("no_prints") && whisper_params.Get("no_prints").IsBoolean()) {
412412
no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
413413
}
414-
414+
415415
bool no_timestamps = false;
416416
if (whisper_params.Has("no_timestamps") && whisper_params.Get("no_timestamps").IsBoolean()) {
417417
no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
418418
}
419-
419+
420420
bool detect_language = false;
421421
if (whisper_params.Has("detect_language") && whisper_params.Get("detect_language").IsBoolean()) {
422422
detect_language = whisper_params.Get("detect_language").As<Napi::Boolean>();
423423
}
424-
424+
425425
int32_t audio_ctx = 0;
426426
if (whisper_params.Has("audio_ctx") && whisper_params.Get("audio_ctx").IsNumber()) {
427427
audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
428428
}
429-
429+
430430
bool comma_in_time = true;
431431
if (whisper_params.Has("comma_in_time") && whisper_params.Get("comma_in_time").IsBoolean()) {
432432
comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
433433
}
434-
434+
435435
int32_t max_len = 0;
436436
if (whisper_params.Has("max_len") && whisper_params.Get("max_len").IsNumber()) {
437437
max_len = whisper_params.Get("max_len").As<Napi::Number>();
438438
}
439-
439+
440440
// Add support for max_context
441441
int32_t max_context = -1;
442442
if (whisper_params.Has("max_context") && whisper_params.Get("max_context").IsNumber()) {
443443
max_context = whisper_params.Get("max_context").As<Napi::Number>();
444444
}
445-
445+
446446
// support prompt
447447
std::string prompt = "";
448448
if (whisper_params.Has("prompt") && whisper_params.Get("prompt").IsString()) {
449449
prompt = whisper_params.Get("prompt").As<Napi::String>();
450450
}
451-
451+
452452
// Add support for print_progress
453453
bool print_progress = false;
454454
if (whisper_params.Has("print_progress") && whisper_params.Get("print_progress").IsBoolean()) {
@@ -465,37 +465,37 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
465465
if (whisper_params.Has("vad") && whisper_params.Get("vad").IsBoolean()) {
466466
vad = whisper_params.Get("vad").As<Napi::Boolean>();
467467
}
468-
468+
469469
std::string vad_model = "";
470470
if (whisper_params.Has("vad_model") && whisper_params.Get("vad_model").IsString()) {
471471
vad_model = whisper_params.Get("vad_model").As<Napi::String>();
472472
}
473-
473+
474474
float vad_threshold = 0.5f;
475475
if (whisper_params.Has("vad_threshold") && whisper_params.Get("vad_threshold").IsNumber()) {
476476
vad_threshold = whisper_params.Get("vad_threshold").As<Napi::Number>();
477477
}
478-
478+
479479
int vad_min_speech_duration_ms = 250;
480480
if (whisper_params.Has("vad_min_speech_duration_ms") && whisper_params.Get("vad_min_speech_duration_ms").IsNumber()) {
481481
vad_min_speech_duration_ms = whisper_params.Get("vad_min_speech_duration_ms").As<Napi::Number>();
482482
}
483-
483+
484484
int vad_min_silence_duration_ms = 100;
485485
if (whisper_params.Has("vad_min_silence_duration_ms") && whisper_params.Get("vad_min_silence_duration_ms").IsNumber()) {
486486
vad_min_silence_duration_ms = whisper_params.Get("vad_min_silence_duration_ms").As<Napi::Number>();
487487
}
488-
488+
489489
float vad_max_speech_duration_s = FLT_MAX;
490490
if (whisper_params.Has("vad_max_speech_duration_s") && whisper_params.Get("vad_max_speech_duration_s").IsNumber()) {
491491
vad_max_speech_duration_s = whisper_params.Get("vad_max_speech_duration_s").As<Napi::Number>();
492492
}
493-
493+
494494
int vad_speech_pad_ms = 30;
495495
if (whisper_params.Has("vad_speech_pad_ms") && whisper_params.Get("vad_speech_pad_ms").IsNumber()) {
496496
vad_speech_pad_ms = whisper_params.Get("vad_speech_pad_ms").As<Napi::Number>();
497497
}
498-
498+
499499
float vad_samples_overlap = 0.1f;
500500
if (whisper_params.Has("vad_samples_overlap") && whisper_params.Get("vad_samples_overlap").IsNumber()) {
501501
vad_samples_overlap = whisper_params.Get("vad_samples_overlap").As<Napi::Number>();

examples/bench/bench.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct whisper_params {
1313
std::string model = "models/ggml-base.en.bin";
1414

1515
bool use_gpu = true;
16-
bool flash_attn = false;
16+
bool flash_attn = true;
1717
};
1818

1919
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -26,11 +26,12 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
2626
whisper_print_usage(argc, argv, params);
2727
exit(0);
2828
}
29-
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
30-
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
31-
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
32-
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
33-
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
29+
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
30+
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
31+
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
32+
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
33+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
34+
else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; }
3435
else {
3536
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
3637
whisper_print_usage(argc, argv, params);
@@ -46,15 +47,16 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
4647
fprintf(stderr, "usage: %s [options]\n", argv[0]);
4748
fprintf(stderr, "\n");
4849
fprintf(stderr, "options:\n");
49-
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
50-
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
51-
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
52-
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
53-
fprintf(stderr, " %-7s 0 - whisper\n", "");
54-
fprintf(stderr, " %-7s 1 - memcpy\n", "");
55-
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
56-
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
57-
fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
50+
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
51+
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
52+
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
53+
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
54+
fprintf(stderr, " %-7s 0 - whisper\n", "");
55+
fprintf(stderr, " %-7s 1 - memcpy\n", "");
56+
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
57+
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
58+
fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
59+
fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true");
5860
fprintf(stderr, "\n");
5961
}
6062

examples/cli/cli.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ struct whisper_params {
7575
bool no_timestamps = false;
7676
bool log_score = false;
7777
bool use_gpu = true;
78-
bool flash_attn = false;
78+
bool flash_attn = true;
7979
bool suppress_nst = false;
8080

8181
std::string language = "en";
@@ -193,6 +193,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
193193
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
194194
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
195195
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
196+
else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; }
196197
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
197198
else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; }
198199
else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
@@ -271,7 +272,8 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
271272
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
272273
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
273274
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
274-
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
275+
fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
276+
fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true");
275277
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
276278
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
277279
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());

examples/command/command.cpp

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct whisper_params {
4242
bool print_energy = false;
4343
bool no_timestamps = true;
4444
bool use_gpu = true;
45-
bool flash_attn = false;
45+
bool flash_attn = true;
4646

4747
std::string language = "en";
4848
std::string model = "models/ggml-base.en.bin";
@@ -66,28 +66,29 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
6666
whisper_print_usage(argc, argv, params);
6767
exit(0);
6868
}
69-
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
70-
else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); }
71-
else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); }
72-
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
73-
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
74-
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
75-
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
76-
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
77-
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
78-
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
79-
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
80-
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
81-
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
82-
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
83-
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
84-
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
85-
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
86-
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
87-
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
88-
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
89-
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
90-
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
69+
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
70+
else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); }
71+
else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); }
72+
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
73+
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
74+
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
75+
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
76+
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
77+
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
78+
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
79+
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
80+
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
81+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
82+
else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; }
83+
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
84+
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
85+
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
86+
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
87+
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
88+
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
89+
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
90+
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
91+
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
9192
else {
9293
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
9394
whisper_print_usage(argc, argv, params);
@@ -116,7 +117,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
116117
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
117118
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
118119
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
119-
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
120+
fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
121+
fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true");
120122
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
121123
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
122124
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());

0 commit comments

Comments
 (0)