mirror of https://github.com/royshil/obs-localvocal.git
synced 2026-01-08 20:08:08 -05:00
Upgrade silero vad v5 (and some other changes) (#148)
* Add accessor for VAD window size in samples
* Feed buffered audio data to VAD in proper window sizes
* Wake whisper thread whenever audio is received
* Update silero VAD to v5
* Only reset VAD state between chunks of activity
Binary file not shown.
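For orientation before the hunks: the heart of the v5 upgrade is a changed I/O contract in the Silero ONNX graph. The two 2x1x64 recurrent tensors h and c of v4 are merged into one 2x1x128 state tensor, input nodes go from {"input", "sr", "h", "c"} to {"input", "state", "sr"}, and outputs from {"output", "hn", "cn"} to {"output", "stateN"}. Below is a minimal standalone sketch of one v5 step with the ONNX Runtime C++ API, assuming an already-loaded v5 session; the function name and layout are illustrative, not the plugin's code.

#include <onnxruntime_cxx_api.h>

#include <array>
#include <cstdint>
#include <cstring>
#include <vector>

// One v5 VAD step: feed a 512-sample 16 kHz window plus the 2x1x128
// recurrent state, read back a speech probability and the updated state.
float silero_v5_step(Ort::Session &session, std::vector<float> &window,
		     std::vector<float> &state) // state: 2*1*128 floats, zeroed at chunk start
{
	Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
	const int64_t input_dims[2] = {1, (int64_t)window.size()};
	const int64_t state_dims[3] = {2, 1, 128};
	const int64_t sr_dims[1] = {1};
	std::vector<int64_t> sr = {16000};

	std::array<Ort::Value, 3> inputs = {
		Ort::Value::CreateTensor<float>(mem, window.data(), window.size(), input_dims, 2),
		Ort::Value::CreateTensor<float>(mem, state.data(), state.size(), state_dims, 3),
		Ort::Value::CreateTensor<int64_t>(mem, sr.data(), sr.size(), sr_dims, 1)};
	const char *in_names[] = {"input", "state", "sr"};
	const char *out_names[] = {"output", "stateN"};

	auto outputs = session.Run(Ort::RunOptions{nullptr}, in_names, inputs.data(),
				   inputs.size(), out_names, 2);

	// Carry stateN into the next window so the recurrence continues.
	std::memcpy(state.data(), outputs[1].GetTensorMutableData<float>(),
		    state.size() * sizeof(float));
	return outputs[0].GetTensorMutableData<float>()[0];
}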
@@ -48,6 +48,7 @@ struct transcription_filter_data {
 	/* Resampler */
 	audio_resampler_t *resampler_to_whisper;
+	struct circlebuf resampled_buffer;

 	/* whisper */
 	std::string whisper_model_path;
@@ -108,6 +108,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_
 		// calculate timestamp offset from the start of the stream
 		info.timestamp_offset_ns = now_ns() - gf->start_timestamp_ms * 1000000;
 		circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
+		gf->wshiper_thread_cv.notify_one();
 	}

 	return audio;
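The notify_one() added here pairs with a wait on the whisper thread (the wait side is outside this diff). In sketch form this is the standard condition-variable producer/consumer handshake; every name below is a stand-in, not the plugin's actual members.

#include <condition_variable>
#include <mutex>

std::mutex buf_mutex;           // stands in for gf->whisper_buf_mutex
std::condition_variable buf_cv; // stands in for gf->wshiper_thread_cv
bool audio_pending = false;

// Audio-filter side: push samples, then wake the worker (this hunk's added line).
void on_audio_pushed()
{
	{
		std::lock_guard<std::mutex> lock(buf_mutex);
		audio_pending = true;
	}
	buf_cv.notify_one();
}

// Whisper-thread side: sleep until audio arrives instead of polling.
void wait_for_audio()
{
	std::unique_lock<std::mutex> lock(buf_mutex);
	buf_cv.wait(lock, [] { return audio_pending; });
	audio_pending = false;
}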
@@ -154,6 +155,8 @@ void transcription_filter_destroy(void *data)
 	}
 	circlebuf_free(&gf->info_buffer);

+	circlebuf_free(&gf->resampled_buffer);
+
 	if (gf->captions_monitor.isEnabled()) {
 		gf->captions_monitor.stopThread();
 	}
@@ -443,6 +446,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
 	}
 	circlebuf_init(&gf->info_buffer);
 	circlebuf_init(&gf->whisper_buffer);
+	circlebuf_init(&gf->resampled_buffer);

 	// allocate copy buffers
 	gf->copy_buffers[0] =
@@ -92,11 +92,11 @@ void VadIterator::init_onnx_model(const SileroString &model_path)
 	session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
 };

-void VadIterator::reset_states(bool reset_hc)
+void VadIterator::reset_states(bool reset_state)
 {
-	if (reset_hc) {
-		std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
-		std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
+	if (reset_state) {
+		// Call reset before each audio start
+		std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
 		triggered = false;
 	}
 	temp_end = 0;
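This flag is the hook for the "only reset VAD state between chunks of activity" item: the recurrent state is zeroed only when the flag is set, so a caller passing false keeps the model's context across consecutive windows of one utterance. The call, as it appears in the segmentation hunk further down, with the intent spelled out in comments:

// last_vad_state.vad_on == true  -> previous window ended mid-speech:
//   pass reset_state = false and keep the recurrent state across windows.
// last_vad_state.vad_on == false -> silence ended the previous window:
//   pass reset_state = true and start the new chunk from a zeroed state.
gf->vad->process(vad_input, !last_vad_state.vad_on);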
@@ -115,19 +115,16 @@ float VadIterator::predict_one(const std::vector<float> &data)
 	input.assign(data.begin(), data.end());
 	Ort::Value input_ort = Ort::Value::CreateTensor<float>(memory_info, input.data(),
								input.size(), input_node_dims, 2);
+	Ort::Value state_ort = Ort::Value::CreateTensor<float>(
+		memory_info, _state.data(), _state.size(), state_node_dims, 3);
 	Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(memory_info, sr.data(), sr.size(),
							      sr_node_dims, 1);
-	Ort::Value h_ort =
-		Ort::Value::CreateTensor<float>(memory_info, _h.data(), _h.size(), hc_node_dims, 3);
-	Ort::Value c_ort =
-		Ort::Value::CreateTensor<float>(memory_info, _c.data(), _c.size(), hc_node_dims, 3);

 	// Clear and add inputs
 	ort_inputs.clear();
 	ort_inputs.emplace_back(std::move(input_ort));
+	ort_inputs.emplace_back(std::move(state_ort));
 	ort_inputs.emplace_back(std::move(sr_ort));
-	ort_inputs.emplace_back(std::move(h_ort));
-	ort_inputs.emplace_back(std::move(c_ort));

 	// Infer
 	ort_outputs = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(),
@@ -136,10 +133,8 @@ float VadIterator::predict_one(const std::vector<float> &data)

 	// Output probability & update h,c recursively
 	float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
-	float *hn = ort_outputs[1].GetTensorMutableData<float>();
-	std::memcpy(_h.data(), hn, size_hc * sizeof(float));
-	float *cn = ort_outputs[2].GetTensorMutableData<float>();
-	std::memcpy(_c.data(), cn, size_hc * sizeof(float));
+	float *stateN = ort_outputs[1].GetTensorMutableData<float>();
+	std::memcpy(_state.data(), stateN, size_state * sizeof(float));

 	return speech_prob;
 }
@@ -264,9 +259,9 @@ void VadIterator::predict(const std::vector<float> &data)
 	}
 };

-void VadIterator::process(const std::vector<float> &input_wav, bool reset_hc)
+void VadIterator::process(const std::vector<float> &input_wav, bool reset_state)
 {
-	reset_states(reset_hc);
+	reset_states(reset_state);

 	audio_length_samples = (int)input_wav.size();

@@ -290,7 +285,7 @@ void VadIterator::process(const std::vector<float> &input_wav, bool reset_hc)

 void VadIterator::process(const std::vector<float> &input_wav, std::vector<float> &output_wav)
 {
-	process(input_wav, true);
+	process(input_wav);
 	collect_chunks(input_wav, output_wav);
 }

@@ -352,8 +347,7 @@ VadIterator::VadIterator(const SileroString &ModelPath, int Sample_rate, int win
 	input_node_dims[0] = 1;
 	input_node_dims[1] = window_size_samples;

-	_h.resize(size_hc);
-	_c.resize(size_hc);
+	_state.resize(size_state);
 	sr.resize(1);
 	sr[0] = sample_rate;
 };
@@ -43,18 +43,20 @@ private:
 private:
 	void init_engine_threads(int inter_threads, int intra_threads);
 	void init_onnx_model(const SileroString &model_path);
-	void reset_states(bool reset_hc);
+	void reset_states(bool reset_state);
 	float predict_one(const std::vector<float> &data);
 	void predict(const std::vector<float> &data);

 public:
-	void process(const std::vector<float> &input_wav, bool reset_hc = true);
+	void process(const std::vector<float> &input_wav, bool reset_state = true);
 	void process(const std::vector<float> &input_wav, std::vector<float> &output_wav);
 	void collect_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
 	const std::vector<timestamp_t> get_speech_timestamps() const;
 	void drop_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
 	void set_threshold(float threshold_) { this->threshold = threshold_; }
+
+	int64_t get_window_size_samples() const { return window_size_samples; }

 private:
 	// model config
 	int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
@@ -84,27 +86,26 @@ private:
 	// Inputs
 	std::vector<Ort::Value> ort_inputs;

-	std::vector<const char *> input_node_names = {"input", "sr", "h", "c"};
+	std::vector<const char *> input_node_names = {"input", "state", "sr"};
 	std::vector<float> input;
+	unsigned int size_state = 2 * 1 * 128; // It's FIXED.
+	std::vector<float> _state;
 	std::vector<int64_t> sr;
-	unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
-	std::vector<float> _h;
-	std::vector<float> _c;

 	int64_t input_node_dims[2] = {};
+	const int64_t state_node_dims[3] = {2, 1, 128};
 	const int64_t sr_node_dims[1] = {1};
-	const int64_t hc_node_dims[3] = {2, 1, 64};

 	// Outputs
 	std::vector<Ort::Value> ort_outputs;
-	std::vector<const char *> output_node_names = {"output", "hn", "cn"};
+	std::vector<const char *> output_node_names = {"output", "stateN"};

 public:
 	// Construction
 	VadIterator(const SileroString &ModelPath, int Sample_rate = 16000,
-		    int windows_frame_size = 64, float Threshold = 0.5,
-		    int min_silence_duration_ms = 0, int speech_pad_ms = 64,
-		    int min_speech_duration_ms = 64,
+		    int windows_frame_size = 32, float Threshold = 0.5,
+		    int min_silence_duration_ms = 0, int speech_pad_ms = 32,
+		    int min_speech_duration_ms = 32,
		    float max_speech_duration_s = std::numeric_limits<float>::infinity());

 	// Default constructor
@@ -2,6 +2,8 @@

 #include <obs-module.h>

+#include <util/profiler.hpp>
+
 #include "plugin-support.h"
 #include "transcription-filter-data.h"
 #include "whisper-processing.h"
@@ -392,22 +394,43 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
 					  num_frames_from_infos, overlap_size);
 	gf->last_num_frames = num_frames_from_infos + overlap_size;

-	// resample to 16kHz
-	float *resampled_16khz[MAX_PREPROC_CHANNELS];
-	uint32_t resampled_16khz_frames;
-	uint64_t ts_offset;
-	audio_resampler_resample(gf->resampler_to_whisper, (uint8_t **)resampled_16khz,
-				 &resampled_16khz_frames, &ts_offset,
-				 (const uint8_t **)gf->copy_buffers,
-				 (uint32_t)num_frames_from_infos);
-
-	obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms", (int)gf->channels,
-		(int)resampled_16khz_frames,
-		(float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f);
-
-	std::vector<float> vad_input(resampled_16khz[0],
-				     resampled_16khz[0] + resampled_16khz_frames);
-	gf->vad->process(vad_input, false);
+	{
+		// resample to 16kHz
+		float *resampled_16khz[MAX_PREPROC_CHANNELS];
+		uint32_t resampled_16khz_frames;
+		uint64_t ts_offset;
+		{
+			ProfileScope("resample");
+			audio_resampler_resample(gf->resampler_to_whisper,
+						 (uint8_t **)resampled_16khz,
+						 &resampled_16khz_frames, &ts_offset,
+						 (const uint8_t **)gf->copy_buffers,
+						 (uint32_t)num_frames_from_infos);
+		}
+
+		obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms",
+			(int)gf->channels, (int)resampled_16khz_frames,
+			(float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f);
+		circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0],
+				    resampled_16khz_frames * sizeof(float));
+	}
+
+	if (gf->resampled_buffer.size < (gf->vad->get_window_size_samples() * sizeof(float)))
+		return last_vad_state;
+
+	size_t len =
+		gf->resampled_buffer.size / (gf->vad->get_window_size_samples() * sizeof(float));
+
+	std::vector<float> vad_input;
+	vad_input.resize(len * gf->vad->get_window_size_samples());
+	circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(),
+			    vad_input.size() * sizeof(float));
+
+	obs_log(gf->log_level, "sending %d frames to vad", vad_input.size());
+	{
+		ProfileScope("vad->process");
+		gf->vad->process(vad_input, !last_vad_state.vad_on);
+	}

 	const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000;
 	const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000;
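Note the invariant the rewritten block establishes: the VAD only ever sees whole windows. Resampled audio accumulates in resampled_buffer, the largest multiple of get_window_size_samples() is popped, and the remainder stays buffered for the next pass. The arithmetic as a standalone sketch (function name illustrative; 512 samples = 32 ms at 16 kHz for the v5 model):

#include <cstddef>

// How many samples can be sent to the VAD so it receives an exact
// multiple of its window size? Anything less than one window waits.
size_t whole_window_samples(size_t buffered_bytes, size_t window_size_samples)
{
	const size_t window_bytes = window_size_samples * sizeof(float);
	if (buffered_bytes < window_bytes)
		return 0;
	return (buffered_bytes / window_bytes) * window_size_samples;
}

// e.g. 9000 buffered floats (36000 bytes) with a 512-sample window:
// 17 whole windows -> 8704 samples popped, 296 samples left buffered.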
@@ -417,8 +440,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v

 	std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
 	if (stamps.size() == 0) {
-		obs_log(gf->log_level, "VAD detected no speech in %d frames",
-			resampled_16khz_frames);
+		obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size());
 		if (last_vad_state.vad_on) {
 			obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference");
 			run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms,
@@ -428,7 +450,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
 		}

 		if (gf->enable_audio_chunks_callback) {
-			audio_chunk_callback(gf, resampled_16khz[0], resampled_16khz_frames,
+			audio_chunk_callback(gf, vad_input.data(), vad_input.size(),
					     VAD_STATE_IS_OFF,
					     {DETECTION_RESULT_SILENCE,
					      "[silence]",
@@ -452,16 +474,16 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
 		}

 		int end_frame = stamps[i].end;
-		if (i == stamps.size() - 1 && stamps[i].end < (int)resampled_16khz_frames) {
+		if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) {
 			// take at least 100ms of audio after the last speech segment, if available
 			end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10,
-					     (int)resampled_16khz_frames);
+					     (int)vad_input.size());
 		}

 		const int number_of_frames = end_frame - start_frame;

 		// push the data into gf-whisper_buffer
-		circlebuf_push_back(&gf->whisper_buffer, resampled_16khz[0] + start_frame,
+		circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame,
				    number_of_frames * sizeof(float));

 		obs_log(gf->log_level,
@@ -472,7 +494,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
			gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE);

 		// segment "end" is in the middle of the buffer, send it to inference
-		if (stamps[i].end < (int)resampled_16khz_frames) {
+		if (stamps[i].end < (int)vad_input.size()) {
 			// new "ending" segment (not up to the end of the buffer)
 			obs_log(gf->log_level, "VAD segment end -> send to inference");
 			// find the end timestamp of the segment
@@ -545,30 +567,24 @@ void whisper_loop(void *data)
 	obs_log(gf->log_level, "Starting whisper thread");

 	vad_state current_vad_state = {false, 0, 0, 0};
-	// 500 ms worth of audio is needed for VAD segmentation
-	uint32_t min_num_bytes_for_vad = (gf->sample_rate / 2) * sizeof(float);
+
+	const char *whisper_loop_name = "Whisper loop";
+	profile_register_root(whisper_loop_name, 50 * 1000 * 1000);

 	// Thread main loop
 	while (true) {
+		ProfileScope(whisper_loop_name);
 		{
+			ProfileScope("lock whisper ctx");
 			std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
+			ProfileScope("locked whisper ctx");
 			if (gf->whisper_context == nullptr) {
 				obs_log(LOG_WARNING, "Whisper context is null, exiting thread");
 				break;
 			}
 		}

-		uint32_t num_bytes_on_input = 0;
-		{
-			// scoped lock the buffer mutex
-			std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
-			num_bytes_on_input = (uint32_t)gf->input_buffers[0].size;
-		}
-
-		// only run vad segmentation if there are at least 500 ms of audio in the buffer
-		if (num_bytes_on_input > min_num_bytes_for_vad) {
-			current_vad_state = vad_based_segmentation(gf, current_vad_state);
-		}
+		current_vad_state = vad_based_segmentation(gf, current_vad_state);

 		if (!gf->cleared_last_sub) {
 			// check if we should clear the current sub depending on the minimum subtitle duration