Overlap analysis (#92)

* Update buffer size and overlap size in whisper-processing.h and default buffer size in msec in transcription-filter.cpp

* Update suppress_sentences in en-US.ini and transcription-filter-data.h

* Update suppress_sentences and fix whitespace in transcription-filter-data.h, whisper-processing.h, transcription-utils.cpp, and transcription-filter.h

* Update whisper-processing.cpp and whisper-utils.cpp files

* Update findStartOfOverlap function signature to use int instead of size_t

* Update Whispercpp_Build_GIT_TAG to use commit 7395c70a748753e3800b63e3422a2b558a097c80 in BuildWhispercpp.cmake

* Update unused parameter in transcription-filter-properties function

* Update log level and add suppress_sentences feature in transcription-filter.cpp and whisper-processing.cpp

* Add translation output feature in en-US.ini and transcription-filter-data.h

* Add DTW token timestamps and buffered output feature

* trigger rebuild

* Refactor remove_leading_trailing_nonalpha function to improve readability and performance

* Refactor is_lead_byte and is_trail_byte macros for improved readability and maintainability

* trigger build

Author: Roy Shilkrot
Date: 2024-04-25 17:14:13 -04:00 (committed by GitHub)
Parent: 65da380f9f · Commit: ab1b74a35c
16 changed files with 715 additions and 387 deletions

CMakeLists.txt

@@ -86,12 +86,14 @@ target_sources(
PRIVATE src/plugin-main.c
src/transcription-filter.cpp
src/transcription-filter.c
src/transcription-utils.cpp
src/model-utils/model-downloader.cpp
src/model-utils/model-downloader-ui.cpp
src/model-utils/model-infos.cpp
src/whisper-utils/whisper-processing.cpp
src/whisper-utils/whisper-utils.cpp
src/whisper-utils/silero-vad-onnx.cpp
src/whisper-utils/token-buffer-thread.cpp
src/translation/translation.cpp
src/utils.cpp)

cmake/BuildWhispercpp.cmake

@@ -107,12 +107,12 @@ elseif(WIN32)
install(FILES ${WHISPER_DLLS} DESTINATION "obs-plugins/64bit")
else()
set(Whispercpp_Build_GIT_TAG "f22d27a385d34b1e544031efe8aa2e3d73922791")
set(Whispercpp_Build_GIT_TAG "7395c70a748753e3800b63e3422a2b558a097c80")
set(WHISPER_EXTRA_CXX_FLAGS "-fPIC")
set(WHISPER_ADDITIONAL_CMAKE_ARGS -DWHISPER_BLAS=OFF -DWHISPER_CUBLAS=OFF -DWHISPER_OPENBLAS=OFF -DWHISPER_NO_AVX=ON
-DWHISPER_NO_AVX2=ON)
# On Linux and MacOS build a static Whisper library
# On Linux build a static Whisper library
ExternalProject_Add(
Whispercpp_Build
DOWNLOAD_EXTRACT_TIMESTAMP true
@@ -133,7 +133,7 @@ else()
ExternalProject_Get_Property(Whispercpp_Build INSTALL_DIR)
# on Linux and MacOS add the static Whisper library to the link line
# add the static Whisper library to the link line
add_library(Whispercpp::Whisper STATIC IMPORTED)
set_target_properties(
Whispercpp::Whisper

data/locale/en-US.ini

@@ -51,3 +51,7 @@ translate_add_context="Translate with context"
whisper_translate="Translate to English (Whisper)"
buffer_size_msec="Buffer size (ms)"
overlap_size_msec="Overlap size (ms)"
suppress_sentences="Suppress sentences (each line)"
translate_output="Translation output"
dtw_token_timestamps="DTW token timestamps"
buffered_output="Buffered output (Experimental)"
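
These keys surface in the filter properties through obs_module_text, aliased as MT_ in transcription-filter-data.h. A minimal sketch of the lookup, matching the properties code later in this diff:

// "suppress_sentences" resolves to "Suppress sentences (each line)" above
obs_properties_add_text(ppts, "suppress_sentences", MT_("suppress_sentences"),
OBS_TEXT_MULTILINE);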

src/captions-thread.h (deleted)

@@ -1,118 +0,0 @@
#ifndef CAPTIONS_THREAD_H
#define CAPTIONS_THREAD_H
#include <queue>
#include <vector>
#include <chrono>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <string>
#include <obs.h>
#include "plugin-support.h"
class CaptionMonitor {
public:
// default constructor
CaptionMonitor() = default;
~CaptionMonitor()
{
{
std::lock_guard<std::mutex> lock(queueMutex);
stop = true;
}
condVar.notify_all();
workerThread.join();
}
void initialize(std::function<void(const std::string &)> callback_, size_t maxSize_,
std::chrono::seconds maxTime_)
{
this->callback = callback_;
this->maxSize = maxSize_;
this->maxTime = maxTime_;
this->initialized = true;
this->workerThread = std::thread(&CaptionMonitor::monitor, this);
}
void addWords(const std::vector<std::string> &words)
{
{
std::lock_guard<std::mutex> lock(queueMutex);
for (const auto &word : words) {
wordQueue.push_back(word);
}
this->newDataAvailable = true;
}
condVar.notify_all();
}
private:
void monitor()
{
obs_log(LOG_INFO, "CaptionMonitor::monitor");
auto startTime = std::chrono::steady_clock::now();
while (true) {
std::unique_lock<std::mutex> lock(this->queueMutex);
// wait for new data or stop signal
this->condVar.wait(lock,
[this] { return this->newDataAvailable || this->stop; });
if (this->stop) {
break;
}
if (this->wordQueue.empty()) {
continue;
}
// emit up to maxSize words from the wordQueue
std::vector<std::string> emitted;
while (!this->wordQueue.empty() && emitted.size() <= this->maxSize) {
emitted.push_back(this->wordQueue.front());
this->wordQueue.pop_front();
}
// emit the caption, joining the words with a space
std::string output;
for (const auto &word : emitted) {
output += word + " ";
}
this->callback(output);
// push back the words that were emitted, in reverse order
for (auto it = emitted.rbegin(); it != emitted.rend(); ++it) {
this->wordQueue.push_front(*it);
}
if (this->wordQueue.size() >= this->maxSize ||
std::chrono::steady_clock::now() - startTime >= this->maxTime) {
// flush the queue if it's full or we've reached the max time
size_t words_to_flush =
std::min(this->wordQueue.size(), this->maxSize);
for (size_t i = 0; i < words_to_flush; ++i) {
wordQueue.pop_front();
}
startTime = std::chrono::steady_clock::now();
}
newDataAvailable = false;
}
obs_log(LOG_INFO, "CaptionMonitor::monitor: done");
}
std::deque<std::string> wordQueue;
std::thread workerThread;
std::mutex queueMutex;
std::condition_variable condVar;
std::function<void(std::string)> callback;
size_t maxSize;
std::chrono::seconds maxTime;
bool stop;
bool initialized = false;
bool newDataAvailable = false;
};
#endif // CAPTIONS_THREAD_H

src/transcription-filter-data.h

@@ -17,25 +17,13 @@
#include "translation/translation.h"
#include "whisper-utils/silero-vad-onnx.h"
#include "captions-thread.h"
#include "whisper-utils/whisper-processing.h"
#include "whisper-utils/token-buffer-thread.h"
#define MAX_PREPROC_CHANNELS 10
#define MT_ obs_module_text
enum DetectionResult {
DETECTION_RESULT_UNKNOWN = 0,
DETECTION_RESULT_SILENCE = 1,
DETECTION_RESULT_SPEECH = 2,
};
struct DetectionResultWithText {
DetectionResult result;
std::string text;
uint64_t start_timestamp_ms;
uint64_t end_timestamp_ms;
};
struct transcription_filter_data {
obs_source_t *context; // obs filter source (this filter)
size_t channels; // number of channels
@@ -64,7 +52,7 @@ struct transcription_filter_data {
struct circlebuf input_buffers[MAX_PREPROC_CHANNELS];
/* Resampler */
audio_resampler_t *resampler;
audio_resampler_t *resampler_to_whisper;
/* whisper */
std::string whisper_model_path;
@@ -90,15 +78,16 @@ struct transcription_filter_data {
bool translate = false;
std::string source_lang;
std::string target_lang;
std::string translation_output;
bool buffered_output = false;
bool enable_token_ts_dtw = false;
std::string suppress_sentences;
// Last transcription result
std::string last_text;
// Text source to output the subtitles
obs_weak_source_t *text_source;
char *text_source_name;
std::mutex *text_source_mutex;
std::string text_source_name;
// Callback to set the text in the output text source (subtitles)
std::function<void(const DetectionResultWithText &result)> setTextCallback;
// Output file path to write the subtitles
@@ -115,7 +104,7 @@ struct transcription_filter_data {
// translation context
struct translation_context translation_ctx;
CaptionMonitor captions_monitor;
TokenBufferThread captions_monitor;
// ctor
transcription_filter_data()
@@ -125,11 +114,9 @@ struct transcription_filter_data {
copy_buffers[i] = nullptr;
}
context = nullptr;
resampler = nullptr;
resampler_to_whisper = nullptr;
whisper_model_path = "";
whisper_context = nullptr;
text_source = nullptr;
text_source_mutex = nullptr;
whisper_buf_mutex = nullptr;
whisper_ctx_mutex = nullptr;
wshiper_thread_cv = nullptr;

src/transcription-filter.cpp

@@ -4,6 +4,7 @@
#include "plugin-support.h"
#include "transcription-filter.h"
#include "transcription-filter-data.h"
#include "transcription-utils.h"
#include "model-utils/model-downloader.h"
#include "whisper-utils/whisper-processing.h"
#include "whisper-utils/whisper-language.h"
@@ -132,18 +133,8 @@ void transcription_filter_destroy(void *data)
obs_log(gf->log_level, "filter destroy");
shutdown_whisper_thread(gf);
if (gf->text_source_name) {
bfree(gf->text_source_name);
gf->text_source_name = nullptr;
}
if (gf->text_source) {
obs_weak_source_release(gf->text_source);
gf->text_source = nullptr;
}
if (gf->resampler) {
audio_resampler_destroy(gf->resampler);
if (gf->resampler_to_whisper) {
audio_resampler_destroy(gf->resampler_to_whisper);
}
{
@@ -159,87 +150,14 @@ void transcription_filter_destroy(void *data)
delete gf->whisper_buf_mutex;
delete gf->whisper_ctx_mutex;
delete gf->wshiper_thread_cv;
delete gf->text_source_mutex;
delete gf;
}
void acquire_weak_text_source_ref(struct transcription_filter_data *gf)
void send_caption_to_source(const std::string &target_source_name, const std::string &str_copy,
struct transcription_filter_data *gf)
{
if (!gf->text_source_name) {
obs_log(gf->log_level, "text_source_name is null");
return;
}
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
// acquire a weak ref to the new text source
obs_source_t *source = obs_get_source_by_name(gf->text_source_name);
if (source) {
gf->text_source = obs_source_get_weak_source(source);
obs_source_release(source);
if (!gf->text_source) {
obs_log(LOG_ERROR, "failed to get weak source for text source %s",
gf->text_source_name);
}
} else {
obs_log(LOG_ERROR, "text source '%s' not found", gf->text_source_name);
}
}
#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
#define is_trail_byte(c) (((c)&0xc0) == 0x80)
inline int lead_byte_length(const uint8_t c)
{
if ((c & 0xe0) == 0xc0) {
return 2;
} else if ((c & 0xf0) == 0xe0) {
return 3;
} else if ((c & 0xf8) == 0xf0) {
return 4;
} else {
return 1;
}
}
inline bool is_valid_lead_byte(const uint8_t *c)
{
const int length = lead_byte_length(c[0]);
if (length == 1) {
return true;
}
if (length == 2 && is_trail_byte(c[1])) {
return true;
}
if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) {
return true;
}
if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) {
return true;
}
return false;
}
void send_caption_to_source(const std::string &str_copy, struct transcription_filter_data *gf)
{
if (!gf->text_source_mutex) {
obs_log(LOG_ERROR, "text_source_mutex is null");
return;
}
if (!gf->text_source) {
// attempt to acquire a weak ref to the text source if it's not yet available
acquire_weak_text_source_ref(gf);
}
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
if (!gf->text_source) {
obs_log(gf->log_level, "text_source is null");
return;
}
auto target = obs_weak_source_get_source(gf->text_source);
auto target = obs_get_source_by_name(target_source_name.c_str());
if (!target) {
obs_log(gf->log_level, "text_source target is null");
return;
@@ -267,52 +185,9 @@ void set_text_callback(struct transcription_filter_data *gf,
}
gf->last_sub_render_time = now;
#ifdef _WIN32
// Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs
// 0xf?, and 0xc? becomes 0xe?, so we need to fix it.
std::stringstream ss;
uint8_t *c_str = (uint8_t *)result.text.c_str();
for (size_t i = 0; i < result.text.size(); ++i) {
if (is_lead_byte(c_str[i])) {
// this is a unicode leading byte
// if the next char is 0xff - it's a bug char, replace it with 0x9f
if (c_str[i + 1] == 0xff) {
c_str[i + 1] = 0x9f;
}
if (!is_valid_lead_byte(c_str + i)) {
// This is a bug lead byte, because it's length 3 and the i+2 byte is also
// a lead byte
c_str[i] = c_str[i] - 0x20;
}
} else {
if (c_str[i] >= 0xf8) {
// this may be a malformed lead byte.
// let's see if it becomes a valid lead byte if we "fix" it
uint8_t buf_[4];
buf_[0] = c_str[i] - 0x20;
buf_[1] = c_str[i + 1];
buf_[2] = c_str[i + 2];
buf_[3] = c_str[i + 3];
if (is_valid_lead_byte(buf_)) {
// this is a malformed lead byte, fix it
c_str[i] = c_str[i] - 0x20;
}
}
}
}
std::string str_copy = (char *)c_str;
#else
std::string str_copy = result.text;
#endif
// remove trailing spaces, newlines, tabs or punctuation
str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(),
[](unsigned char ch) {
return !std::isspace(ch) || !std::ispunct(ch);
})
.base(),
str_copy.end());
// recondition the text
std::string str_copy = fix_utf8(result.text);
str_copy = remove_leading_trailing_nonalpha(str_copy);
if (gf->translate) {
obs_log(gf->log_level, "Translating text. %s -> %s", gf->source_lang.c_str(),
@@ -324,7 +199,13 @@ void set_text_callback(struct transcription_filter_data *gf,
obs_log(LOG_INFO, "Translation: '%s' -> '%s'", str_copy.c_str(),
translated_text.c_str());
}
str_copy = translated_text;
if (gf->translation_output == "none") {
// overwrite the original text with the translated text
str_copy = translated_text;
} else {
// send the translation to the selected source
send_caption_to_source(gf->translation_output, translated_text, gf);
}
} else {
obs_log(gf->log_level, "Failed to translate text");
}
@@ -333,7 +214,7 @@ void set_text_callback(struct transcription_filter_data *gf,
gf->last_text = str_copy;
if (gf->buffered_output) {
gf->captions_monitor.addWords(split_words(str_copy));
gf->captions_monitor.addWords(result.tokens);
}
if (gf->caption_to_stream) {
@@ -344,7 +225,7 @@ void set_text_callback(struct transcription_filter_data *gf,
}
}
if (gf->output_file_path != "" && !gf->text_source_name) {
if (gf->output_file_path != "" && gf->text_source_name.empty()) {
// Check if we should save the sentence
if (gf->save_only_while_recording && !obs_frontend_recording_active()) {
// We are not recording, do not save the sentence to file
@@ -396,7 +277,7 @@ void set_text_callback(struct transcription_filter_data *gf,
} else {
if (!gf->buffered_output) {
// Send the caption to the text source
send_caption_to_source(str_copy, gf);
send_caption_to_source(gf->text_source_name, str_copy, gf);
}
}
};
@@ -427,12 +308,21 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->process_while_muted = obs_data_get_bool(s, "process_while_muted");
gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration");
gf->last_sub_render_time = 0;
gf->buffered_output = obs_data_get_bool(s, "buffered_output");
bool new_buffered_output = obs_data_get_bool(s, "buffered_output");
if (new_buffered_output != gf->buffered_output) {
gf->buffered_output = new_buffered_output;
gf->overlap_ms = gf->buffered_output ? MAX_OVERLAP_SIZE_MSEC
: DEFAULT_OVERLAP_SIZE_MSEC;
gf->overlap_frames =
(size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
}
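// Worked example (assuming a 48 kHz input): buffered output raises overlap_ms
// to MAX_OVERLAP_SIZE_MSEC (1000), so overlap_frames = 48000 / (1000 / 1000) =
// 48000; the DEFAULT_OVERLAP_SIZE_MSEC of 100 gives 48000 / (1000 / 100) = 4800.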
bool new_translate = obs_data_get_bool(s, "translate");
gf->source_lang = obs_data_get_string(s, "translate_source_language");
gf->target_lang = obs_data_get_string(s, "translate_target_language");
gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context");
gf->translation_output = obs_data_get_string(s, "translate_output");
gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences");
if (new_translate != gf->translate) {
if (new_translate) {
@@ -451,19 +341,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
strcmp(new_text_source_name, "(null)") == 0 ||
strcmp(new_text_source_name, "text_file") == 0 || strlen(new_text_source_name) == 0) {
// new selected text source is not valid, release the old one
if (gf->text_source) {
if (!gf->text_source_mutex) {
obs_log(LOG_ERROR, "text_source_mutex is null");
return;
}
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
old_weak_text_source = gf->text_source;
gf->text_source = nullptr;
}
if (gf->text_source_name) {
bfree(gf->text_source_name);
gf->text_source_name = nullptr;
}
gf->text_source_name.clear();
gf->output_file_path = "";
if (strcmp(new_text_source_name, "text_file") == 0) {
// set the output file path
@@ -475,24 +353,9 @@ void transcription_filter_update(void *data, obs_data_t *s)
}
} else {
// new selected text source is valid, check if it's different from the old one
if (gf->text_source_name == nullptr ||
strcmp(new_text_source_name, gf->text_source_name) != 0) {
if (gf->text_source_name != new_text_source_name) {
// new text source is different from the old one, release the old one
if (gf->text_source) {
if (!gf->text_source_mutex) {
obs_log(LOG_ERROR, "text_source_mutex is null");
return;
}
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
old_weak_text_source = gf->text_source;
gf->text_source = nullptr;
}
if (gf->text_source_name) {
// free the old text source name
bfree(gf->text_source_name);
gf->text_source_name = nullptr;
}
gf->text_source_name = bstrdup(new_text_source_name);
gf->text_source_name = new_text_source_name;
}
}
@@ -507,7 +370,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
}
obs_log(gf->log_level, "update whisper model");
update_whsiper_model_path(gf, s);
update_whsiper_model(gf, s);
obs_log(gf->log_level, "update whisper params");
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
@@ -597,15 +460,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
dst.format = AUDIO_FORMAT_FLOAT_PLANAR;
dst.speakers = convert_speaker_layout((uint8_t)1);
gf->resampler = audio_resampler_create(&dst, &src);
gf->resampler_to_whisper = audio_resampler_create(&dst, &src);
obs_log(gf->log_level, "setup mutexes and condition variables");
gf->whisper_buf_mutex = new std::mutex();
gf->whisper_ctx_mutex = new std::mutex();
gf->wshiper_thread_cv = new std::condition_variable();
gf->text_source_mutex = new std::mutex();
obs_log(gf->log_level, "clear text source data");
gf->text_source = nullptr;
const char *subtitle_sources = obs_data_get_string(settings, "subtitle_sources");
if (subtitle_sources == nullptr || strcmp(subtitle_sources, "none") == 0 ||
strcmp(subtitle_sources, "(null)") == 0 || strlen(subtitle_sources) == 0) {
@@ -619,8 +480,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
// create a new OBS text source called "LocalVocal Subtitles"
obs_source_t *scene_as_source = obs_frontend_get_current_scene();
obs_scene_t *scene = obs_scene_from_source(scene_as_source);
#ifdef _WIN32
source = obs_source_create("text_gdiplus_v2", "LocalVocal Subtitles",
nullptr, nullptr);
#else
source = obs_source_create("text_ft2_source_v2", "LocalVocal Subtitles",
nullptr, nullptr);
#endif
if (source) {
// add source to the current scene
obs_scene_add(scene, source);
@@ -660,11 +526,11 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
}
obs_source_release(scene_as_source);
}
gf->text_source_name = bstrdup("LocalVocal Subtitles");
gf->text_source_name = "LocalVocal Subtitles";
obs_data_set_string(settings, "subtitle_sources", "LocalVocal Subtitles");
} else {
// set the text source name
gf->text_source_name = bstrdup(subtitle_sources);
gf->text_source_name = subtitle_sources;
}
obs_log(gf->log_level, "clear paths and whisper context");
gf->whisper_model_file_currently_loaded = "";
@@ -673,13 +539,14 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
gf->whisper_context = nullptr;
gf->captions_monitor.initialize(
gf,
[gf](const std::string &text) {
obs_log(LOG_INFO, "Captions: %s", text.c_str());
if (gf->buffered_output) {
send_caption_to_source(text, gf);
send_caption_to_source(gf->text_source_name, text, gf);
}
},
20, std::chrono::seconds(10));
30, std::chrono::seconds(10));
obs_log(gf->log_level, "run update");
// get the settings updated on the filter data struct
@@ -791,6 +658,7 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_string(s, "translate_target_language", "__es__");
obs_data_set_default_string(s, "translate_source_language", "__en__");
obs_data_set_default_bool(s, "translate_add_context", true);
obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT);
// Whisper parameters
obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
@@ -805,6 +673,7 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_bool(s, "print_realtime", false);
obs_data_set_default_bool(s, "print_timestamps", false);
obs_data_set_default_bool(s, "token_timestamps", false);
obs_data_set_default_bool(s, "dtw_token_timestamps", false);
obs_data_set_default_double(s, "thold_pt", 0.01);
obs_data_set_default_double(s, "thold_ptsum", 0.01);
obs_data_set_default_int(s, "max_len", 0);
@@ -919,6 +788,15 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_property_list_add_string(prop_src, language.second.c_str(),
language.first.c_str());
}
// add option for routing the translation to an output source
obs_property_t *prop_output = obs_properties_add_list(translation_group, "translate_output",
MT_("translate_output"),
OBS_COMBO_TYPE_LIST,
OBS_COMBO_FORMAT_STRING);
obs_property_list_add_string(prop_output, "Write to captions output", "none");
// TODO add file output option
// obs_property_list_add_string(...
obs_enum_sources(add_sources_to_list, prop_output);
// add callback to enable/disable translation group
obs_property_set_modified_callback(translation_group_prop, [](obs_properties_t *props,
@@ -928,7 +806,7 @@ obs_properties_t *transcription_filter_properties(void *data)
// Show/Hide the translation group
const bool translate_enabled = obs_data_get_bool(settings, "translate");
for (const auto &prop : {"translate_target_language", "translate_source_language",
"translate_add_context"}) {
"translate_add_context", "translate_output"}) {
obs_property_set_visible(obs_properties_get(props, prop),
translate_enabled);
}
@@ -946,21 +824,38 @@ obs_properties_t *transcription_filter_properties(void *data)
for (const std::string &prop_name :
{"whisper_params_group", "log_words", "caption_to_stream", "buffer_size_msec",
"overlap_size_msec", "step_by_step_processing", "min_sub_duration",
"process_while_muted", "buffered_output", "vad_enabled", "log_level"}) {
"process_while_muted", "buffered_output", "vad_enabled", "log_level",
"suppress_sentences"}) {
obs_property_set_visible(obs_properties_get(props, prop_name.c_str()),
show_hide);
}
return true;
});
obs_properties_add_bool(ppts, "buffered_output", MT_("buffered_output"));
obs_property_t *buffered_output_prop =
obs_properties_add_bool(ppts, "buffered_output", MT_("buffered_output"));
// add on-change handler for buffered_output
obs_property_set_modified_callback(buffered_output_prop, [](obs_properties_t *props,
obs_property_t *property,
obs_data_t *settings) {
UNUSED_PARAMETER(property);
UNUSED_PARAMETER(props);
// if buffered output is enabled set the overlap to max else set it to default
obs_data_set_int(settings, "overlap_size_msec",
obs_data_get_bool(settings, "buffered_output")
? MAX_OVERLAP_SIZE_MSEC
: DEFAULT_OVERLAP_SIZE_MSEC);
return true;
});
obs_properties_add_bool(ppts, "log_words", MT_("log_words"));
obs_properties_add_bool(ppts, "caption_to_stream", MT_("caption_to_stream"));
obs_properties_add_int_slider(ppts, "buffer_size_msec", MT_("buffer_size_msec"), 1000,
DEFAULT_BUFFER_SIZE_MSEC, 250);
obs_properties_add_int_slider(ppts, "overlap_size_msec", MT_("overlap_size_msec"), 50, 300,
50);
obs_properties_add_int_slider(ppts, "overlap_size_msec", MT_("overlap_size_msec"),
MIN_OVERLAP_SIZE_MSEC, MAX_OVERLAP_SIZE_MSEC,
(MAX_OVERLAP_SIZE_MSEC - MIN_OVERLAP_SIZE_MSEC) / 5);
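// With MIN_OVERLAP_SIZE_MSEC = 100 and MAX_OVERLAP_SIZE_MSEC = 1000 this
// renders a 100-1000 ms slider with a (1000 - 100) / 5 = 180 ms step.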
obs_property_t *step_by_step_processing = obs_properties_add_bool(
ppts, "step_by_step_processing", MT_("step_by_step_processing"));
@@ -985,10 +880,14 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_property_t *list = obs_properties_add_list(ppts, "log_level", MT_("log_level"),
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
obs_property_list_add_int(list, "DEBUG", LOG_DEBUG);
obs_property_list_add_int(list, "DEBUG (Won't show)", LOG_DEBUG);
obs_property_list_add_int(list, "INFO", LOG_INFO);
obs_property_list_add_int(list, "WARNING", LOG_WARNING);
// add a text input for sentences to suppress
obs_properties_add_text(ppts, "suppress_sentences", MT_("suppress_sentences"),
OBS_TEXT_MULTILINE);
obs_properties_t *whisper_params_group = obs_properties_create();
obs_properties_add_group(ppts, "whisper_params_group", MT_("whisper_parameters"),
OBS_GROUP_NORMAL, whisper_params_group);
@@ -1043,6 +942,9 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_properties_add_bool(whisper_params_group, "print_timestamps", MT_("print_timestamps"));
// bool token_timestamps; // enable token-level timestamps
obs_properties_add_bool(whisper_params_group, "token_timestamps", MT_("token_timestamps"));
// enable DTW timestamps
obs_properties_add_bool(whisper_params_group, "dtw_token_timestamps",
MT_("dtw_token_timestamps"));
// float thold_pt; // timestamp token probability threshold (~0.01)
obs_properties_add_float_slider(whisper_params_group, "thold_pt", MT_("thold_pt"), 0.0f,
1.0f, 0.05f);

src/transcription-filter.h

@@ -19,6 +19,8 @@ const char *const PLUGIN_INFO_TEMPLATE =
"<a href=\"https://github.com/occ-ai\">OCC AI</a> ❤️ "
"<a href=\"https://www.patreon.com/RoyShilkrot\">Support & Follow</a>";
const char *const SUPPRESS_SENTENCES_DEFAULT = "Thank you for watching\nThank you";
#ifdef __cplusplus
}
#endif

src/transcription-utils.cpp (new file)

@@ -0,0 +1,104 @@
#include "transcription-utils.h"
#include <sstream>
#include <algorithm>
#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
#define is_trail_byte(c) (((c)&0xc0) == 0x80)
inline int lead_byte_length(const uint8_t c)
{
if ((c & 0xe0) == 0xc0) {
return 2;
} else if ((c & 0xf0) == 0xe0) {
return 3;
} else if ((c & 0xf8) == 0xf0) {
return 4;
} else {
return 1;
}
}
inline bool is_valid_lead_byte(const uint8_t *c)
{
const int length = lead_byte_length(c[0]);
if (length == 1) {
return true;
}
if (length == 2 && is_trail_byte(c[1])) {
return true;
}
if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) {
return true;
}
if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) {
return true;
}
return false;
}
std::string fix_utf8(const std::string &str)
{
#ifdef _WIN32
// Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs
// 0xf?, and 0xc? becomes 0xe?, so we need to fix it.
std::stringstream ss;
uint8_t *c_str = (uint8_t *)str.c_str();
for (size_t i = 0; i < str.size(); ++i) {
if (is_lead_byte(c_str[i])) {
// this is a unicode leading byte
// if the next char is 0xff - it's a bug char, replace it with 0x9f
if (c_str[i + 1] == 0xff) {
c_str[i + 1] = 0x9f;
}
if (!is_valid_lead_byte(c_str + i)) {
// This is a bug lead byte, because it's length 3 and the i+2 byte is also
// a lead byte
c_str[i] = c_str[i] - 0x20;
}
} else {
if (c_str[i] >= 0xf8) {
// this may be a malformed lead byte.
// let's see if it becomes a valid lead byte if we "fix" it
uint8_t buf_[4];
buf_[0] = c_str[i] - 0x20;
buf_[1] = c_str[i + 1];
buf_[2] = c_str[i + 2];
buf_[3] = c_str[i + 3];
if (is_valid_lead_byte(buf_)) {
// this is a malformed lead byte, fix it
c_str[i] = c_str[i] - 0x20;
}
}
}
}
return std::string((char *)c_str);
#else
return str;
#endif
}
/*
* Remove leading and trailing non-alphabetic characters from a string.
* This function is used to remove leading and trailing spaces, newlines, tabs or punctuation.
* @param str: the string to remove leading and trailing non-alphabetic characters from.
* @return: the string with leading and trailing non-alphabetic characters removed.
*/
std::string remove_leading_trailing_nonalpha(const std::string &str)
{
std::string str_copy = str;
// remove trailing spaces, newlines, tabs or punctuation
str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(),
[](unsigned char ch) {
return !std::isspace(ch) && !std::ispunct(ch);
})
.base(),
str_copy.end());
// remove leading spaces, newlines, tabs or punctuation
str_copy.erase(str_copy.begin(),
std::find_if(str_copy.begin(), str_copy.end(), [](unsigned char ch) {
return !std::isspace(ch) && !std::ispunct(ch);
}));
return str_copy;
}
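
A short usage sketch of the two helpers (the byte values are an illustrative assumption), matching the call chain in set_text_callback above:

// On Windows, "caf\xe3\xa9." is the buggy rendering of "café." (é = 0xC3 0xA9):
// 0xE3 announces a 3-byte sequence but '.' is not a trail byte, so fix_utf8
// shifts the lead back by 0x20 to the valid 2-byte lead 0xC3.
std::string raw = "caf\xe3\xa9.";
std::string cleaned = remove_leading_trailing_nonalpha(fix_utf8(raw));
// cleaned == "café" (the trailing '.' is trimmed as well)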

src/transcription-utils.h (new file)

@@ -0,0 +1,9 @@
#ifndef TRANSCRIPTION_UTILS_H
#define TRANSCRIPTION_UTILS_H
#include <string>
std::string fix_utf8(const std::string &str);
std::string remove_leading_trailing_nonalpha(const std::string &str);
#endif // TRANSCRIPTION_UTILS_H

src/whisper-utils/silero-vad-onnx.cpp

@@ -10,7 +10,10 @@
#include <cstdio>
#include <cstdarg>
//#define __DEBUG_SPEECH_PROB___
#include <obs.h>
#include "plugin-support.h"
// #define __DEBUG_SPEECH_PROB___
timestamp_t::timestamp_t(int start_, int end_) : start(start_), end(end_){};
@@ -144,8 +147,8 @@ void VadIterator::predict(const std::vector<float> &data)
float speech =
current_sample -
window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob,
current_sample - window_size_samples);
obs_log(LOG_INFO, "{ start: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate,
speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
if (temp_end != 0) {
temp_end = 0;
@@ -194,16 +197,18 @@ void VadIterator::predict(const std::vector<float> &data)
float speech =
current_sample -
window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate,
speech_prob, current_sample - window_size_samples);
obs_log(LOG_INFO, "{ speaking: %.3f s (%.3f) %08d}",
1.0 * speech / sample_rate, speech_prob,
current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
} else {
#ifdef __DEBUG_SPEECH_PROB___
float speech =
current_sample -
window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate,
speech_prob, current_sample - window_size_samples);
obs_log(LOG_INFO, "{ silence: %.3f s (%.3f) %08d}",
1.0 * speech / sample_rate, speech_prob,
current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
}
return;
@@ -215,8 +220,8 @@ void VadIterator::predict(const std::vector<float> &data)
float speech =
current_sample - window_size_samples -
speech_pad_samples; // minus window_size_samples to get precise start time point.
printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob,
current_sample - window_size_samples);
obs_log(LOG_INFO, "{ end: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate,
speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
if (triggered == true) {
if (temp_end == 0) {
@@ -285,7 +290,7 @@ void VadIterator::collect_chunks(const std::vector<float> &input_wav,
output_wav.clear();
for (size_t i = 0; i < speeches.size(); i++) {
#ifdef __DEBUG_SPEECH_PROB___
std::cout << speeches[i].c_str() << std::endl;
obs_log(LOG_INFO, "%s", speeches[i].string().c_str());
#endif //#ifdef __DEBUG_SPEECH_PROB___
std::vector<float> slice(&input_wav[speeches[i].start],
&input_wav[speeches[i].end]);

src/whisper-utils/token-buffer-thread.cpp (new file)

@@ -0,0 +1,131 @@
#include "token-buffer-thread.h"
#include "./whisper-utils.h"
TokenBufferThread::~TokenBufferThread()
{
{
std::lock_guard<std::mutex> lock(queueMutex);
stop = true;
}
condVar.notify_all();
// guard: initialize() may never have started the worker thread
if (workerThread.joinable()) {
workerThread.join();
}
}
void TokenBufferThread::initialize(struct transcription_filter_data *gf_,
std::function<void(const std::string &)> callback_,
size_t maxSize_, std::chrono::seconds maxTime_)
{
this->gf = gf_;
this->callback = callback_;
this->maxSize = maxSize_;
this->maxTime = maxTime_;
this->initialized = true;
this->workerThread = std::thread(&TokenBufferThread::monitor, this);
}
void TokenBufferThread::log_token_vector(const std::vector<whisper_token_data> &tokens)
{
std::string output;
for (const auto &token : tokens) {
const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
output += token_str;
}
obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str());
}
void TokenBufferThread::addWords(const std::vector<whisper_token_data> &words)
{
obs_log(LOG_INFO, "TokenBufferThread::addWords");
{
std::lock_guard<std::mutex> lock(queueMutex);
// convert current wordQueue to vector
std::vector<whisper_token_data> currentWords(wordQueue.begin(), wordQueue.end());
log_token_vector(currentWords);
log_token_vector(words);
// run reconstructSentence
std::vector<whisper_token_data> reconstructed =
reconstructSentence(currentWords, words);
log_token_vector(reconstructed);
// clear the wordQueue
wordQueue.clear();
// add the reconstructed sentence to the wordQueue
for (const auto &word : reconstructed) {
wordQueue.push_back(word);
}
newDataAvailable = true;
}
condVar.notify_all();
}
void TokenBufferThread::monitor()
{
obs_log(LOG_INFO, "TokenBufferThread::monitor");
auto startTime = std::chrono::steady_clock::now();
while (this->initialized && !this->stop) {
std::unique_lock<std::mutex> lock(this->queueMutex);
// wait for new data or stop signal
this->condVar.wait(lock, [this] { return this->newDataAvailable || this->stop; });
if (this->stop) {
break;
}
if (this->wordQueue.empty()) {
continue;
}
if (this->gf->whisper_context == nullptr) {
continue;
}
// emit up to maxSize words from the wordQueue
std::vector<whisper_token_data> emitted;
while (!this->wordQueue.empty() && emitted.size() <= this->maxSize) {
emitted.push_back(this->wordQueue.front());
this->wordQueue.pop_front();
}
obs_log(LOG_INFO, "TokenBufferThread::monitor: emitting %d words", emitted.size());
log_token_vector(emitted);
// emit the caption from the tokens
std::string output;
for (const auto &token : emitted) {
const char *token_str =
whisper_token_to_str(this->gf->whisper_context, token.id);
output += token_str;
}
this->callback(output);
// push back the words that were emitted, in reverse order
for (auto it = emitted.rbegin(); it != emitted.rend(); ++it) {
this->wordQueue.push_front(*it);
}
// check if we need to flush the queue
auto elapsedTime = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - startTime);
if (this->wordQueue.size() >= this->maxSize || elapsedTime >= this->maxTime) {
// flush the queue if it's full or we've reached the max time
size_t words_to_flush = std::min(this->wordQueue.size(), this->maxSize);
// make sure we leave at least 3 words in the queue
size_t words_remaining = this->wordQueue.size() - words_to_flush;
if (words_remaining < 3) {
words_to_flush -= 3 - words_remaining;
}
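// e.g. with maxSize = 30 and 31 tokens queued: words_to_flush starts at 30,
// which would leave only 1 token, so it is reduced by 2 to 28, keeping 3
// tokens to anchor the next overlap match.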
obs_log(LOG_INFO, "TokenBufferThread::monitor: flushing %d words",
words_to_flush);
for (size_t i = 0; i < words_to_flush; ++i) {
wordQueue.pop_front();
}
startTime = std::chrono::steady_clock::now();
}
newDataAvailable = false;
}
obs_log(LOG_INFO, "TokenBufferThread::monitor: done");
}
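
A hedged usage sketch, mirroring the initialize call in transcription-filter.cpp above; the callback fires on the worker thread with the detokenized caption:

// assuming gf is a fully constructed transcription_filter_data*
gf->captions_monitor.initialize(
gf,
[gf](const std::string &caption) {
// runs on the TokenBufferThread worker
send_caption_to_source(gf->text_source_name, caption, gf);
},
30, // emit at most 30 tokens per caption
std::chrono::seconds(10)); // flush the queue after at most 10 seconds
// after each inference, feed the kept tokens in:
gf->captions_monitor.addWords(inference_result.tokens);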

src/whisper-utils/token-buffer-thread.h (new file)

@@ -0,0 +1,49 @@
#ifndef TOKEN_BUFFER_THREAD_H
#define TOKEN_BUFFER_THREAD_H
#include <queue>
#include <vector>
#include <chrono>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <string>
#include <obs.h>
#include <whisper.h>
#include "plugin-support.h"
struct transcription_filter_data;
class TokenBufferThread {
public:
// default constructor
TokenBufferThread() = default;
~TokenBufferThread();
void initialize(struct transcription_filter_data *gf,
std::function<void(const std::string &)> callback_, size_t maxSize_,
std::chrono::seconds maxTime_);
void addWords(const std::vector<whisper_token_data> &words);
private:
void monitor();
void log_token_vector(const std::vector<whisper_token_data> &tokens);
struct transcription_filter_data *gf;
std::deque<whisper_token_data> wordQueue;
std::thread workerThread;
std::mutex queueMutex;
std::condition_variable condVar;
std::function<void(std::string)> callback;
size_t maxSize;
std::chrono::seconds maxTime;
bool stop = false;
bool initialized = false;
bool newDataAvailable = false;
};
#endif

src/whisper-utils/whisper-processing.cpp

@@ -5,6 +5,7 @@
#include "plugin-support.h"
#include "transcription-filter-data.h"
#include "whisper-processing.h"
#include "whisper-utils.h"
#include <algorithm>
#include <cctype>
@@ -109,7 +110,8 @@ bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float v
return true;
}
struct whisper_context *init_whisper_context(const std::string &model_path_in)
struct whisper_context *init_whisper_context(const std::string &model_path_in,
struct transcription_filter_data *gf)
{
std::string model_path = model_path_in;
@@ -131,14 +133,15 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in)
whisper_log_set(
[](enum ggml_log_level level, const char *text, void *user_data) {
UNUSED_PARAMETER(level);
UNUSED_PARAMETER(user_data);
struct transcription_filter_data *ctx =
static_cast<struct transcription_filter_data *>(user_data);
// remove trailing newline
char *text_copy = bstrdup(text);
text_copy[strcspn(text_copy, "\n")] = 0;
obs_log(LOG_INFO, "Whisper: %s", text_copy);
obs_log(ctx->log_level, "Whisper: %s", text_copy);
bfree(text_copy);
},
nullptr);
gf);
struct whisper_context_params cparams = whisper_context_default_params();
#ifdef LOCALVOCAL_WITH_CUDA
@@ -152,6 +155,16 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in)
obs_log(LOG_INFO, "Using CPU for inference");
#endif
cparams.dtw_token_timestamps = gf->enable_token_ts_dtw;
if (gf->enable_token_ts_dtw) {
obs_log(LOG_INFO, "DTW token timestamps enabled");
cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
// cparams.dtw_n_top = 4;
} else {
obs_log(LOG_INFO, "DTW token timestamps disabled");
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
}
struct whisper_context *ctx = nullptr;
try {
#ifdef _WIN32
@@ -196,16 +209,19 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in)
}
struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
const float *pcm32f_data, size_t pcm32f_size)
const float *pcm32f_data, size_t pcm32f_size,
bool zero_start)
{
UNUSED_PARAMETER(zero_start);
if (gf == nullptr) {
obs_log(LOG_ERROR, "run_whisper_inference: gf is null");
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
}
if (pcm32f_data == nullptr || pcm32f_size == 0) {
obs_log(LOG_ERROR, "run_whisper_inference: pcm32f_data is null or size is 0");
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
}
obs_log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__,
@@ -215,7 +231,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
if (gf->whisper_context == nullptr) {
obs_log(LOG_WARNING, "whisper context is null");
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
}
// Get the duration in ms since the beginning of the stream (gf->start_timestamp_ms)
@@ -234,47 +250,92 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
obs_log(LOG_ERROR, "Whisper exception: %s. Filter restart is required", e.what());
whisper_free(gf->whisper_context);
gf->whisper_context = nullptr;
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
}
if (whisper_full_result != 0) {
obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result);
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
} else {
// duration in ms
const uint64_t duration_ms = (uint64_t)(pcm32f_size * 1000 / WHISPER_SAMPLE_RATE);
const int n_segment = 0;
const char *text = whisper_full_get_segment_text(gf->whisper_context, n_segment);
// const char *text = whisper_full_get_segment_text(gf->whisper_context, n_segment);
const int64_t t0 = offset_ms;
const int64_t t1 = offset_ms + duration_ms;
float sentence_p = 0.0f;
const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment);
std::string text = "";
std::string tokenIds = "";
std::vector<whisper_token_data> tokens;
bool end = false;
for (int j = 0; j < n_tokens; ++j) {
sentence_p += whisper_full_get_token_p(gf->whisper_context, n_segment, j);
// get token
whisper_token_data token =
whisper_full_get_token_data(gf->whisper_context, n_segment, j);
const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
bool keep = !end;
// if the token starts with '[' and ends with ']', don't keep it
if (token_str[0] == '[' && token_str[strlen(token_str) - 1] == ']') {
keep = false;
}
if ((j == n_tokens - 2 || j == n_tokens - 3) && token.p < 0.5) {
keep = false;
}
// if the second to last token is .id == 13 ('.'), don't keep it
if (j == n_tokens - 2 && token.id == 13) {
keep = false;
}
// token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json
if (token.id > 50540 && token.id <= 51865) {
obs_log(gf->log_level,
"Large time token found (%d), this shouldn't happen",
token.id);
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
}
if (keep) {
text += token_str;
tokenIds += std::to_string(token.id) + " (" +
std::string(token_str) + "), ";
tokens.push_back(token);
}
obs_log(gf->log_level, "Token %d: %d, %s, p: %.3f, dtw: %ld [keep: %d]", j,
token.id, token_str, token.p, token.t_dtw, keep);
}
sentence_p /= (float)n_tokens;
obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str());
obs_log(gf->log_level, "Token IDs: %s", tokenIds.c_str());
// convert text to lowercase
std::string text_lower(text);
std::transform(text_lower.begin(), text_lower.end(), text_lower.begin(), ::tolower);
// trim whitespace (use lambda)
text_lower.erase(std::find_if(text_lower.rbegin(), text_lower.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
text_lower.end());
// if suppression is enabled, check if the text is in the suppression list
if (!gf->suppress_sentences.empty()) {
std::string suppress_sentences_copy = gf->suppress_sentences;
size_t pos = 0;
std::string token;
while ((pos = suppress_sentences_copy.find("\n")) != std::string::npos) {
token = suppress_sentences_copy.substr(0, pos);
suppress_sentences_copy.erase(0, pos + 1);
if (text == token) {
obs_log(gf->log_level, "Suppressing sentence: %s",
text.c_str());
return {DETECTION_RESULT_SUPPRESSED, "", 0, 0, {}};
}
}
}
if (gf->log_words) {
obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
to_timestamp(t1).c_str(), sentence_p, text_lower.c_str());
to_timestamp(t1).c_str(), sentence_p, text.c_str());
}
if (text_lower.empty() || text_lower == ".") {
return {DETECTION_RESULT_SILENCE, "", 0, 0};
if (text.empty() || text == ".") {
return {DETECTION_RESULT_SILENCE, "", 0, 0, {}};
}
return {DETECTION_RESULT_SPEECH, text_lower, offset_ms, offset_ms + duration_ms};
return {DETECTION_RESULT_SPEECH, text, offset_ms, offset_ms + duration_ms, tokens};
}
}
@@ -307,14 +368,14 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
num_new_frames_from_infos -= info_from_buf.frames;
circlebuf_push_front(&gf->info_buffer, &info_from_buf,
size_of_audio_info);
last_step_in_segment =
true; // this is the final step in the segment
// this is the final step in the segment
last_step_in_segment = true;
break;
}
}
obs_log(gf->log_level,
"with %lu remaining to full segment, popped %d info-frames, pushing into buffer at %lu",
"with %lu remaining to full segment, popped %d info-frames, pushing at %lu (overlap)",
remaining_frames_to_full_segment, num_new_frames_from_infos,
gf->last_num_frames);
@@ -340,7 +401,7 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
}
} else {
gf->last_num_frames = num_new_frames_from_infos;
obs_log(gf->log_level, "first segment, %d frames to process",
obs_log(gf->log_level, "first segment, no overlap exists, %d frames to process",
(int)(gf->last_num_frames));
}
@@ -352,51 +413,92 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
auto start = std::chrono::high_resolution_clock::now();
// resample to 16kHz
float *output[MAX_PREPROC_CHANNELS];
uint32_t out_frames;
float *resampled_16khz[MAX_PREPROC_CHANNELS];
uint32_t resampled_16khz_frames;
uint64_t ts_offset;
audio_resampler_resample(gf->resampler, (uint8_t **)output, &out_frames, &ts_offset,
audio_resampler_resample(gf->resampler_to_whisper, (uint8_t **)resampled_16khz,
&resampled_16khz_frames, &ts_offset,
(const uint8_t **)gf->copy_buffers, (uint32_t)gf->last_num_frames);
obs_log(gf->log_level, "%d channels, %d frames, %f ms", (int)gf->channels, (int)out_frames,
(float)out_frames / WHISPER_SAMPLE_RATE * 1000.0f);
obs_log(gf->log_level, "%d channels, %d frames, %f ms", (int)gf->channels,
(int)resampled_16khz_frames,
(float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f);
bool skipped_inference = false;
uint32_t speech_start_frame = 0;
uint32_t speech_end_frame = out_frames;
uint32_t speech_end_frame = resampled_16khz_frames;
if (gf->vad_enabled) {
std::vector<float> vad_input(output[0], output[0] + out_frames);
std::vector<float> vad_input(resampled_16khz[0],
resampled_16khz[0] + resampled_16khz_frames);
gf->vad->process(vad_input);
auto stamps = gf->vad->get_speech_timestamps();
std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
if (stamps.size() == 0) {
obs_log(gf->log_level, "VAD detected no speech in %d frames",
resampled_16khz_frames);
skipped_inference = true;
// prevent copying the buffer to the beginning (overlap)
gf->last_num_frames = 0;
last_step_in_segment = false;
} else {
speech_start_frame = stamps[0].start;
speech_start_frame = (stamps[0].start < 3000) ? 0 : stamps[0].start;
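// 3000 frames at WHISPER_SAMPLE_RATE (16 kHz) is ~187 ms: speech detected that
// close to the window start is snapped to frame 0, presumably so the overlap
// region at the head of the buffer is not trimmed away.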
speech_end_frame = stamps.back().end;
obs_log(gf->log_level, "VAD detected speech from %d to %d",
speech_start_frame, speech_end_frame);
uint32_t number_of_frames = speech_end_frame - speech_start_frame;
obs_log(gf->log_level,
"VAD detected speech from %d to %d (%d frames, %d ms)",
speech_start_frame, speech_end_frame, number_of_frames,
number_of_frames * 1000 / WHISPER_SAMPLE_RATE);
// if the speech segment is less than 1 second - put the audio back into the buffer
// to be handled in the next iteration
if (number_of_frames > 0 && number_of_frames < WHISPER_SAMPLE_RATE) {
// convert speech_start_frame and speech_end_frame to original sample rate
speech_start_frame =
speech_start_frame * gf->sample_rate / WHISPER_SAMPLE_RATE;
speech_end_frame =
speech_end_frame * gf->sample_rate / WHISPER_SAMPLE_RATE;
number_of_frames = speech_end_frame - speech_start_frame;
// use memmove to copy the speech segment to the beginning of the buffer
for (size_t c = 0; c < gf->channels; c++) {
memmove(gf->copy_buffers[c],
gf->copy_buffers[c] + speech_start_frame,
number_of_frames * sizeof(float));
}
obs_log(gf->log_level,
"Speech segment is less than 1 second, moving %d to %d (len %d) to buffer start",
speech_start_frame, speech_end_frame, number_of_frames);
// no processing of the segment
skipped_inference = true;
// reset the last_num_frames to the number of frames in the buffer
gf->last_num_frames = number_of_frames;
// prevent copying the buffer to the beginning (overlap)
last_step_in_segment = false;
}
}
}
if (!skipped_inference) {
// run inference
const struct DetectionResultWithText inference_result = run_whisper_inference(
gf, output[0] + speech_start_frame, speech_end_frame - speech_start_frame);
gf, resampled_16khz[0] + speech_start_frame,
speech_end_frame - speech_start_frame, speech_start_frame == 0);
if (inference_result.result == DETECTION_RESULT_SPEECH) {
// output inference result to a text source
set_text_callback(gf, inference_result);
} else if (inference_result.result == DETECTION_RESULT_SILENCE) {
// output inference result to a text source
set_text_callback(gf, {inference_result.result, "[silence]", 0, 0});
set_text_callback(gf, {inference_result.result, "[silence]", 0, 0, {}});
}
} else {
if (gf->log_words) {
obs_log(LOG_INFO, "skipping inference");
}
set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "[skip]", 0, 0});
set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "[skip]", 0, 0, {}});
}
// end of timer
@@ -407,6 +509,12 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
(int)duration);
if (last_step_in_segment) {
const uint64_t overlap_size_ms =
(uint64_t)(gf->overlap_frames * 1000 / gf->sample_rate);
obs_log(gf->log_level,
"copying %lu frames (%lu ms) from the end of the buffer (pos %lu) to the beginning",
gf->overlap_frames, overlap_size_ms,
gf->last_num_frames - gf->overlap_frames);
for (size_t c = 0; c < gf->channels; c++) {
// This is the last step in the segment - reset the copy buffer (include overlap frames)
// move overlap frames from the end of the last copy_buffers to the beginning
@@ -416,8 +524,8 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
// zero out the rest of the buffer, just in case
memset(gf->copy_buffers[c] + gf->overlap_frames, 0,
(gf->frames - gf->overlap_frames) * sizeof(float));
gf->last_num_frames = gf->overlap_frames;
}
gf->last_num_frames = gf->overlap_frames;
}
}
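
A minimal sketch (pcm32f and n_samples are hypothetical caller names; the flow follows process_audio_from_buffer above) of consuming the extended DetectionResultWithText, whose new tokens field feeds the buffered output:

const struct DetectionResultWithText r =
run_whisper_inference(gf, pcm32f, n_samples, /*zero_start=*/true);
if (r.result == DETECTION_RESULT_SPEECH) {
set_text_callback(gf, r); // forwards r.tokens to captions_monitor.addWords
} else if (r.result == DETECTION_RESULT_SUPPRESSED) {
// the sentence matched the suppress_sentences list; nothing is emitted
}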

src/whisper-utils/whisper-processing.h

@@ -1,12 +1,32 @@
#ifndef WHISPER_PROCESSING_H
#define WHISPER_PROCESSING_H
#include <whisper.h>
// buffer size in msec
#define DEFAULT_BUFFER_SIZE_MSEC 2000
#define DEFAULT_BUFFER_SIZE_MSEC 3000
// overlap in msec
#define DEFAULT_OVERLAP_SIZE_MSEC 100
#define MAX_OVERLAP_SIZE_MSEC 1000
#define MIN_OVERLAP_SIZE_MSEC 100
enum DetectionResult {
DETECTION_RESULT_UNKNOWN = 0,
DETECTION_RESULT_SILENCE = 1,
DETECTION_RESULT_SPEECH = 2,
DETECTION_RESULT_SUPPRESSED = 3,
};
struct DetectionResultWithText {
DetectionResult result;
std::string text;
uint64_t start_timestamp_ms;
uint64_t end_timestamp_ms;
std::vector<whisper_token_data> tokens;
};
void whisper_loop(void *data);
struct whisper_context *init_whisper_context(const std::string &model_path);
struct whisper_context *init_whisper_context(const std::string &model_path,
struct transcription_filter_data *gf);
#endif // WHISPER_PROCESSING_H

src/whisper-utils/whisper-utils.cpp

@@ -5,7 +5,7 @@
#include <obs-module.h>
void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s)
void update_whsiper_model(struct transcription_filter_data *gf, obs_data_t *s)
{
// update the whisper model path
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
@@ -13,9 +13,12 @@ void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t
if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
is_external_model) {
// model path changed, reload the model
obs_log(gf->log_level, "model path changed from %s to %s",
gf->whisper_model_path.c_str(), new_model_path.c_str());
if (gf->whisper_model_path != new_model_path) {
// model path changed
obs_log(gf->log_level, "model path changed from %s to %s",
gf->whisper_model_path.c_str(), new_model_path.c_str());
}
// check if the new model is external file
if (!is_external_model) {
@@ -76,6 +79,21 @@ void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t
obs_log(gf->log_level, "Model path did not change: %s == %s",
gf->whisper_model_path.c_str(), new_model_path.c_str());
}
const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
if (new_dtw_timestamps != gf->enable_token_ts_dtw) {
// dtw_token_timestamps changed
obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d",
gf->enable_token_ts_dtw, new_dtw_timestamps);
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
shutdown_whisper_thread(gf);
start_whisper_thread_with_path(gf, gf->whisper_model_path);
} else {
// dtw_token_timestamps did not change
obs_log(gf->log_level, "dtw_token_timestamps did not change: %d == %d",
gf->enable_token_ts_dtw, new_dtw_timestamps);
}
}
void shutdown_whisper_thread(struct transcription_filter_data *gf)
@@ -122,9 +140,12 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
#else
std::string silero_vad_model_path = silero_vad_model_file;
#endif
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE));
// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
// for silero vad parameters
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 64, 0.5f, 1000,
200, 250));
gf->whisper_context = init_whisper_context(path);
gf->whisper_context = init_whisper_context(path, gf);
if (gf->whisper_context == nullptr) {
obs_log(LOG_ERROR, "Failed to initialize whisper context");
return;
@@ -133,3 +154,100 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
std::thread new_whisper_thread(whisper_loop, gf);
gf->whisper_thread.swap(new_whisper_thread);
}
// Finds start of 2-token overlap between two sequences of tokens
// Returns a pair of indices of the first overlapping tokens in the two sequences
// If no overlap is found, the function returns {-1, -1}
// Allows for a single token mismatch in the overlap
std::pair<int, int> findStartOfOverlap(const std::vector<whisper_token_data> &seq1,
const std::vector<whisper_token_data> &seq2)
{
if (seq1.empty() || seq2.empty() || seq1.size() == 1 || seq2.size() == 1) {
return {-1, -1};
}
for (size_t i = seq1.size() - 2; i >= seq1.size() / 2; --i) {
for (size_t j = 0; j < seq2.size() - 1; ++j) {
if (seq1[i].id == seq2[j].id) {
// Check if the next token in both sequences is the same
if (seq1[i + 1].id == seq2[j + 1].id) {
return {(int)i, (int)j};
}
// 1-skip check on seq1
if (i + 2 < seq1.size() && seq1[i + 2].id == seq2[j + 1].id) {
return {(int)i, (int)j};
}
// 1-skip check on seq2
if (j + 2 < seq2.size() && seq1[i + 1].id == seq2[j + 2].id) {
return {(int)i, (int)j};
}
}
}
}
return {-1, -1};
}
// Function to reconstruct a whole sentence from two sentences using overlap info
// If no overlap is found, the function returns the concatenation of the two sequences
std::vector<whisper_token_data> reconstructSentence(const std::vector<whisper_token_data> &seq1,
const std::vector<whisper_token_data> &seq2)
{
auto overlap = findStartOfOverlap(seq1, seq2);
std::vector<whisper_token_data> reconstructed;
if (overlap.first == -1 || overlap.second == -1) {
if (seq1.empty() && seq2.empty()) {
return reconstructed;
}
if (seq1.empty()) {
return seq2;
}
if (seq2.empty()) {
return seq1;
}
// Return concat of seq1 and seq2 if no overlap found
// check if the last token of seq1 == the first token of seq2
if (seq1.back().id == seq2.front().id) {
// don't add the last token of seq1
reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
} else if (seq2.size() > 1ull && seq1.back().id == seq2[1].id) {
// check if the last token of seq1 == the second token of seq2
// don't add the last token of seq1
reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
// don't add the first token of seq2
reconstructed.insert(reconstructed.end(), seq2.begin() + 1, seq2.end());
} else if (seq1.size() > 1ull && seq1[seq1.size() - 2].id == seq2.front().id) {
// check if the second to last token of seq1 == the first token of seq2
// don't add the last two tokens of seq1
reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 2);
reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
} else {
// add all tokens of seq1
reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end());
reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
}
return reconstructed;
}
// Add tokens from the first sequence up to the overlap
reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.begin() + overlap.first);
// Determine the length of the overlap
size_t overlapLength = 0;
while (overlap.first + overlapLength < seq1.size() &&
overlap.second + overlapLength < seq2.size() &&
seq1[overlap.first + overlapLength].id == seq2[overlap.second + overlapLength].id) {
overlapLength++;
}
// Add overlapping tokens
reconstructed.insert(reconstructed.end(), seq1.begin() + overlap.first,
seq1.begin() + overlap.first + overlapLength);
// Add remaining tokens from the second sequence
reconstructed.insert(reconstructed.end(), seq2.begin() + overlap.second + overlapLength,
seq2.end());
return reconstructed;
}
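
A worked example of the stitch (token ids are illustrative; make_token is a hypothetical helper):

static whisper_token_data make_token(whisper_token id)
{
whisper_token_data t{};
t.id = id;
return t;
}
// window 1 ends with ids 1 2 3 4; window 2 starts with 3 4 5 6:
std::vector<whisper_token_data> seq1 = {make_token(1), make_token(2), make_token(3), make_token(4)};
std::vector<whisper_token_data> seq2 = {make_token(3), make_token(4), make_token(5), make_token(6)};
// findStartOfOverlap(seq1, seq2) == {2, 0}: the two-token run (3, 4) matches,
// so reconstructSentence yields ids 1 2 3 4 5 6 with the overlap deduplicated.
std::vector<whisper_token_data> merged = reconstructSentence(seq1, seq2);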

src/whisper-utils/whisper-utils.h

@@ -7,8 +7,13 @@
#include <string>
void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s);
void update_whsiper_model(struct transcription_filter_data *gf, obs_data_t *s);
void shutdown_whisper_thread(struct transcription_filter_data *gf);
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path);
std::pair<int, int> findStartOfOverlap(const std::vector<whisper_token_data> &seq1,
const std::vector<whisper_token_data> &seq2);
std::vector<whisper_token_data> reconstructSentence(const std::vector<whisper_token_data> &seq1,
const std::vector<whisper_token_data> &seq2);
#endif /* WHISPER_UTILS_H */