diff --git a/bindings/go/README.md b/bindings/go/README.md index 6958ede80f2..cbd2a622874 100644 --- a/bindings/go/README.md +++ b/bindings/go/README.md @@ -31,7 +31,7 @@ func main() { if err != nil { panic(err) } - if err := context.Process(samples, nil, nil); err != nil { + if err := context.Process(samples, nil, nil, nil); err != nil { return err } diff --git a/bindings/go/examples/go-whisper/process.go b/bindings/go/examples/go-whisper/process.go index 71e52f01000..833947e843c 100644 --- a/bindings/go/examples/go-whisper/process.go +++ b/bindings/go/examples/go-whisper/process.go @@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error { // Process the data fmt.Fprintf(flags.Output(), " ...processing %q\n", path) context.ResetTimings() - if err := context.Process(data, cb, nil); err != nil { + if err := context.Process(data, nil, cb, nil); err != nil { return err } diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index 06376b1b870..3d4bbe98945 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -189,6 +189,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f // Process new sample data and return any errors func (context *context) Process( data []float32, + callEncoderBegin EncoderBeginCallback, callNewSegment SegmentCallback, callProgress ProgressCallback, ) error { @@ -203,7 +204,20 @@ func (context *context) Process( // We don't do parallel processing at the moment processors := 0 if processors > 1 { - if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) { + if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin, + func(new int) { + if callNewSegment != nil { + num_segments := context.model.ctx.Whisper_full_n_segments() + s0 := num_segments - new + for i := s0; i < num_segments; i++ { + callNewSegment(toSegment(context.model.ctx, 
i)) + } + } + }); err != nil { + return err + } + } else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin, + func(new int) { if callNewSegment != nil { num_segments := context.model.ctx.Whisper_full_n_segments() s0 := num_segments - new @@ -211,22 +225,11 @@ func (context *context) Process( callNewSegment(toSegment(context.model.ctx, i)) } } - }); err != nil { - return err - } - } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) { - if callNewSegment != nil { - num_segments := context.model.ctx.Whisper_full_n_segments() - s0 := num_segments - new - for i := s0; i < num_segments; i++ { - callNewSegment(toSegment(context.model.ctx, i)) + }, func(progress int) { + if callProgress != nil { + callProgress(progress) } - } - }, func(progress int) { - if callProgress != nil { - callProgress(progress) - } - }); err != nil { + }); err != nil { return err } @@ -312,6 +315,10 @@ func (context *context) IsLANG(t Token, lang string) bool { } } +func (context *context) GetDetectedLanguage() string { + return whisper.Whisper_lang_str(context.model.ctx.Whisper_full_lang_id()) +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index 8981b1a8116..1721215317e 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -16,6 +16,10 @@ type SegmentCallback func(Segment) // processing. It is called during the Process function type ProgressCallback func(int) +// EncoderBeginCallback is the callback function for checking if we want to +// continue processing. It is called during the Process function +type EncoderBeginCallback func() bool + // Model is the interface to a whisper model. 
Create a new model with the // function whisper.New(string) type Model interface { @@ -31,12 +35,14 @@ type Model interface { Languages() []string } -// Context is the speach recognition context. +// Context is the speech recognition context. type Context interface { - SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language. - SetTranslate(bool) // Set translate flag - IsMultilingual() bool // Return true if the model is multilingual. - Language() string // Get language + SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language. + SetTranslate(bool) // Set translate flag + IsMultilingual() bool // Return true if the model is multilingual. + Language() string // Get language + GetDetectedLanguage() string // Get the auto-detected language + SetOffset(time.Duration) // Set offset SetDuration(time.Duration) // Set duration @@ -58,7 +64,7 @@ type Context interface { // Process mono audio data and return any errors. // If defined, newly generated segments are passed to the // callback function during processing. - Process([]float32, SegmentCallback, ProgressCallback) error + Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error // After process is called, return segments until the end of the stream // is reached, when io.EOF is returned. 
diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go index 68a150223c7..f3631c0e990 100644 --- a/bindings/go/pkg/whisper/model.go +++ b/bindings/go/pkg/whisper/model.go @@ -87,7 +87,7 @@ func (model *model) NewContext() (Context, error) { } // Create new context - params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY) + params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_BEAM_SEARCH) params.SetTranslate(false) params.SetPrintSpecial(false) params.SetPrintProgress(false) diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 39ec43b47ed..cdae877ebec 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -296,7 +296,13 @@ func Whisper_print_system_info() string { // Return default parameters for a strategy func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params { // Get default parameters - return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy))) + p := Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy))) + + p.greedy.best_of = 5 + p.thold_pt = 0 + p.thold_ptsum = 0 + + return p } // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text