Skip to content

Commit b91fe0f

Browse files
committed
Fixed VAD to work when using whisper_full_with_state
1 parent edea8a9 commit b91fe0f

File tree

1 file changed

+26
-35
lines changed

1 file changed

+26
-35
lines changed

src/whisper.cpp

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6652,8 +6652,8 @@ static bool whisper_vad(
66526652

66536653
if (vad_segments->data.size() > 0) {
66546654
state->has_vad_segments = true;
6655-
ctx->state->vad_segments.clear();
6656-
ctx->state->vad_segments.reserve(vad_segments->data.size());
6655+
state->vad_segments.clear();
6656+
state->vad_segments.reserve(vad_segments->data.size());
66576657

66586658
// Initialize the time mapping table
66596659
state->vad_mapping_table.clear();
@@ -6749,7 +6749,7 @@ static bool whisper_vad(
67496749

67506750
WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
67516751
__func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0);
6752-
ctx->state->vad_segments.push_back(segment);
6752+
state->vad_segments.push_back(segment);
67536753

67546754
// Copy this speech segment
67556755
memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
@@ -6820,6 +6820,24 @@ int whisper_full_with_state(
68206820
}
68216821
}
68226822

6823+
std::vector<float> vad_samples;
6824+
if (params.vad)
6825+
{
6826+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
6827+
if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples))
6828+
{
6829+
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
6830+
return -1;
6831+
}
6832+
if (vad_samples.empty())
6833+
{
6834+
state->result_all.clear();
6835+
return 0;
6836+
}
6837+
samples = vad_samples.data();
6838+
n_samples = vad_samples.size();
6839+
}
6840+
68236841
// auto-detect language if not specified
68246842
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
68256843
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
@@ -7720,25 +7738,11 @@ int whisper_full_with_state(
77207738
}
77217739

77227740
int whisper_full(
7723-
struct whisper_context * ctx,
7724-
struct whisper_full_params params,
7725-
const float * samples,
7726-
int n_samples) {
7727-
7728-
std::vector<float> vad_samples;
7729-
if (params.vad) {
7730-
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
7731-
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
7732-
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
7733-
return -1;
7734-
}
7735-
if (vad_samples.empty()) {
7736-
ctx->state->result_all.clear();
7737-
return 0;
7738-
}
7739-
samples = vad_samples.data();
7740-
n_samples = vad_samples.size();
7741-
}
7741+
struct whisper_context *ctx,
7742+
struct whisper_full_params params,
7743+
const float *samples,
7744+
int n_samples)
7745+
{
77427746
return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
77437747
}
77447748

@@ -7753,19 +7757,6 @@ int whisper_full_parallel(
77537757
return whisper_full(ctx, params, samples, n_samples);
77547758
}
77557759

7756-
std::vector<float> vad_samples;
7757-
if (params.vad) {
7758-
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
7759-
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
7760-
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
7761-
return -1;
7762-
}
7763-
if (vad_samples.empty()) {
7764-
return 0;
7765-
}
7766-
samples = vad_samples.data();
7767-
n_samples = vad_samples.size();
7768-
}
77697760
int ret = 0;
77707761

77717762
// prepare separate states for each thread

0 commit comments

Comments
 (0)