Skip to content

Commit 0c2bfd5

Browse files
committed
bindings-ruby : enable setting of flash-attn
This commit enables the flast_attn context parameter to be set in the Ruby bindings. The motivation for this is that the default setting for this recently changed to true, which causes the following test to fail: ```console Failure: test_transcribe_n_processors(TestWhisper): </ask not what your country can do for you[,.] ask what you can do for your country/i> was expected to be =~ <" And so, my fellow Americans! Ask not what you do. what your country can do for you. Ask what you can do for your country.">. ``` An alternative could also be to change the test.
1 parent 2a56869 commit 0c2bfd5

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

bindings/ruby/ext/ruby_whisper_context.c

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,22 +118,34 @@ ruby_whisper_normalize_model_path(VALUE model_path)
118118
* new("base.en") -> Whisper::Context
119119
* new("path/to/model.bin") -> Whisper::Context
120120
* new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
121+
* new("base.en", flash_attn: false) -> Whisper::Context
121122
*/
122123
static VALUE
123124
ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
124125
{
125126
ruby_whisper *rw;
126-
VALUE whisper_model_file_path;
127+
VALUE whisper_model_file_path, options;
128+
VALUE flash_attn_opt;
127129

128130
// TODO: we can support init from buffer here too maybe another ruby object to expose
129-
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
131+
rb_scan_args(argc, argv, "01:", &whisper_model_file_path, &options);
130132
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
131133

132134
whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path);
133135
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
134136
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
135137
}
136-
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
138+
139+
struct whisper_context_params cparams = whisper_context_default_params();
140+
141+
if (!NIL_P(options)) {
142+
flash_attn_opt = rb_hash_aref(options, ID2SYM(rb_intern("flash_attn")));
143+
if (!NIL_P(flash_attn_opt)) {
144+
cparams.flash_attn = (flash_attn_opt == Qtrue);
145+
}
146+
}
147+
148+
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), cparams);
137149
if (rw->context == NULL) {
138150
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
139151
}

bindings/ruby/test/test_whisper.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_transcribe_non_parallel
3030
end
3131

3232
def test_transcribe_n_processors
33-
@whisper = Whisper::Context.new("base.en")
33+
@whisper = Whisper::Context.new("base.en", flash_attn: false)
3434
params = Whisper::Params.new
3535

3636
@whisper.transcribe(AUDIO, params, n_processors: 4) {|text|

0 commit comments

Comments
 (0)