diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index bc0c6e99194..1b89895796b 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -118,22 +118,34 @@ ruby_whisper_normalize_model_path(VALUE model_path) * new("base.en") -> Whisper::Context * new("path/to/model.bin") -> Whisper::Context * new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context + * new("base.en", flash_attn: false) -> Whisper::Context */ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; - VALUE whisper_model_file_path; + VALUE whisper_model_file_path, options; + VALUE flash_attn_opt; // TODO: we can support init from buffer here too maybe another ruby object to expose - rb_scan_args(argc, argv, "01", &whisper_model_file_path); + rb_scan_args(argc, argv, "01:", &whisper_model_file_path, &options); TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path); if (!rb_respond_to(whisper_model_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } - rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); + + struct whisper_context_params cparams = whisper_context_default_params(); + + if (!NIL_P(options)) { + flash_attn_opt = rb_hash_aref(options, ID2SYM(rb_intern("flash_attn"))); + if (!NIL_P(flash_attn_opt)) { + cparams.flash_attn = (flash_attn_opt == Qtrue); + } + } + + rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), cparams); if (rw->context == NULL) { rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); } diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index 12b82a8de09..9bc04062f52 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -30,7 +30,7 @@ def test_transcribe_non_parallel end def test_transcribe_n_processors - @whisper = Whisper::Context.new("base.en") + @whisper = Whisper::Context.new("base.en", flash_attn: false) params = Whisper::Params.new @whisper.transcribe(AUDIO, params, n_processors: 4) {|text|