mirror of
https://github.com/royshil/obs-localvocal.git
synced 2026-01-10 12:58:10 -05:00
Overlap analysis (#92)
* Update buffer size and overlap size in whisper-processing.h and default buffer size in msec in transcription-filter.cpp * Update buffer size and overlap size in whisper-processing.h and default buffer size in msec in transcription-filter.cpp * Update suppress_sentences in en-US.ini and transcription-filter-data.h * Update suppress_sentences and fix whitespace in transcription-filter-data.h, whisper-processing.h, transcription-utils.cpp, and transcription-filter.h * Update whisper-processing.cpp and whisper-utils.cpp files * Update findStartOfOverlap function signature to use int instead of size_t * Update Whispercpp_Build_GIT_TAG to use commit 7395c70a748753e3800b63e3422a2b558a097c80 in BuildWhispercpp.cmake * Update buffer size and overlap size in whisper-processing.h and default buffer size in msec in transcription-filter.cpp * Update unused parameter in transcription-filter-properties function * Update log level and add suppress_sentences feature in transcription-filter.cpp and whisper-processing.cpp * Add translation output feature in en-US.ini and transcription-filter-data.h * Add DTW token timestamps and buffered output feature * trigger rebuild * Refactor remove_leading_trailing_nonalpha function to improve readability and performance * Refactor is_lead_byte and is_trail_byte macros for improved readability and maintainability * Refactor is_lead_byte and is_trail_byte macros for improved readability and maintainability * trigger build
This commit is contained in:
@@ -86,12 +86,14 @@ target_sources(
|
||||
PRIVATE src/plugin-main.c
|
||||
src/transcription-filter.cpp
|
||||
src/transcription-filter.c
|
||||
src/transcription-utils.cpp
|
||||
src/model-utils/model-downloader.cpp
|
||||
src/model-utils/model-downloader-ui.cpp
|
||||
src/model-utils/model-infos.cpp
|
||||
src/whisper-utils/whisper-processing.cpp
|
||||
src/whisper-utils/whisper-utils.cpp
|
||||
src/whisper-utils/silero-vad-onnx.cpp
|
||||
src/whisper-utils/token-buffer-thread.cpp
|
||||
src/translation/translation.cpp
|
||||
src/utils.cpp)
|
||||
|
||||
|
||||
@@ -107,12 +107,12 @@ elseif(WIN32)
|
||||
install(FILES ${WHISPER_DLLS} DESTINATION "obs-plugins/64bit")
|
||||
|
||||
else()
|
||||
set(Whispercpp_Build_GIT_TAG "f22d27a385d34b1e544031efe8aa2e3d73922791")
|
||||
set(Whispercpp_Build_GIT_TAG "7395c70a748753e3800b63e3422a2b558a097c80")
|
||||
set(WHISPER_EXTRA_CXX_FLAGS "-fPIC")
|
||||
set(WHISPER_ADDITIONAL_CMAKE_ARGS -DWHISPER_BLAS=OFF -DWHISPER_CUBLAS=OFF -DWHISPER_OPENBLAS=OFF -DWHISPER_NO_AVX=ON
|
||||
-DWHISPER_NO_AVX2=ON)
|
||||
|
||||
# On Linux and MacOS build a static Whisper library
|
||||
# On Linux build a static Whisper library
|
||||
ExternalProject_Add(
|
||||
Whispercpp_Build
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP true
|
||||
@@ -133,7 +133,7 @@ else()
|
||||
|
||||
ExternalProject_Get_Property(Whispercpp_Build INSTALL_DIR)
|
||||
|
||||
# on Linux and MacOS add the static Whisper library to the link line
|
||||
# add the static Whisper library to the link line
|
||||
add_library(Whispercpp::Whisper STATIC IMPORTED)
|
||||
set_target_properties(
|
||||
Whispercpp::Whisper
|
||||
|
||||
@@ -51,3 +51,7 @@ translate_add_context="Translate with context"
|
||||
whisper_translate="Translate to English (Whisper)"
|
||||
buffer_size_msec="Buffer size (ms)"
|
||||
overlap_size_msec="Overlap size (ms)"
|
||||
suppress_sentences="Suppress sentences (each line)"
|
||||
translate_output="Translation output"
|
||||
dtw_token_timestamps="DTW token timestamps"
|
||||
buffered_output="Buffered output (Experimental)"
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
#ifndef CAPTIONS_THREAD_H
|
||||
#define CAPTIONS_THREAD_H
|
||||
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include <obs.h>
|
||||
|
||||
#include "plugin-support.h"
|
||||
|
||||
class CaptionMonitor {
|
||||
public:
|
||||
// default constructor
|
||||
CaptionMonitor() = default;
|
||||
|
||||
~CaptionMonitor()
|
||||
{
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(queueMutex);
|
||||
stop = true;
|
||||
}
|
||||
condVar.notify_all();
|
||||
workerThread.join();
|
||||
}
|
||||
|
||||
void initialize(std::function<void(const std::string &)> callback_, size_t maxSize_,
|
||||
std::chrono::seconds maxTime_)
|
||||
{
|
||||
this->callback = callback_;
|
||||
this->maxSize = maxSize_;
|
||||
this->maxTime = maxTime_;
|
||||
this->initialized = true;
|
||||
this->workerThread = std::thread(&CaptionMonitor::monitor, this);
|
||||
}
|
||||
|
||||
void addWords(const std::vector<std::string> &words)
|
||||
{
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(queueMutex);
|
||||
for (const auto &word : words) {
|
||||
wordQueue.push_back(word);
|
||||
}
|
||||
this->newDataAvailable = true;
|
||||
}
|
||||
condVar.notify_all();
|
||||
}
|
||||
|
||||
private:
|
||||
void monitor()
|
||||
{
|
||||
obs_log(LOG_INFO, "CaptionMonitor::monitor");
|
||||
auto startTime = std::chrono::steady_clock::now();
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(this->queueMutex);
|
||||
// wait for new data or stop signal
|
||||
this->condVar.wait(lock,
|
||||
[this] { return this->newDataAvailable || this->stop; });
|
||||
|
||||
if (this->stop) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (this->wordQueue.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// emit up to maxSize words from the wordQueue
|
||||
std::vector<std::string> emitted;
|
||||
while (!this->wordQueue.empty() && emitted.size() <= this->maxSize) {
|
||||
emitted.push_back(this->wordQueue.front());
|
||||
this->wordQueue.pop_front();
|
||||
}
|
||||
// emit the caption, joining the words with a space
|
||||
std::string output;
|
||||
for (const auto &word : emitted) {
|
||||
output += word + " ";
|
||||
}
|
||||
this->callback(output);
|
||||
// push back the words that were emitted, in reverse order
|
||||
for (auto it = emitted.rbegin(); it != emitted.rend(); ++it) {
|
||||
this->wordQueue.push_front(*it);
|
||||
}
|
||||
|
||||
if (this->wordQueue.size() >= this->maxSize ||
|
||||
std::chrono::steady_clock::now() - startTime >= this->maxTime) {
|
||||
// flush the queue if it's full or we've reached the max time
|
||||
size_t words_to_flush =
|
||||
std::min(this->wordQueue.size(), this->maxSize);
|
||||
for (size_t i = 0; i < words_to_flush; ++i) {
|
||||
wordQueue.pop_front();
|
||||
}
|
||||
startTime = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
newDataAvailable = false;
|
||||
}
|
||||
obs_log(LOG_INFO, "CaptionMonitor::monitor: done");
|
||||
}
|
||||
|
||||
std::deque<std::string> wordQueue;
|
||||
std::thread workerThread;
|
||||
std::mutex queueMutex;
|
||||
std::condition_variable condVar;
|
||||
std::function<void(std::string)> callback;
|
||||
size_t maxSize;
|
||||
std::chrono::seconds maxTime;
|
||||
bool stop;
|
||||
bool initialized = false;
|
||||
bool newDataAvailable = false;
|
||||
};
|
||||
|
||||
#endif // CAPTIONS_THREAD_H
|
||||
@@ -17,25 +17,13 @@
|
||||
|
||||
#include "translation/translation.h"
|
||||
#include "whisper-utils/silero-vad-onnx.h"
|
||||
#include "captions-thread.h"
|
||||
#include "whisper-utils/whisper-processing.h"
|
||||
#include "whisper-utils/token-buffer-thread.h"
|
||||
|
||||
#define MAX_PREPROC_CHANNELS 10
|
||||
|
||||
#define MT_ obs_module_text
|
||||
|
||||
enum DetectionResult {
|
||||
DETECTION_RESULT_UNKNOWN = 0,
|
||||
DETECTION_RESULT_SILENCE = 1,
|
||||
DETECTION_RESULT_SPEECH = 2,
|
||||
};
|
||||
|
||||
struct DetectionResultWithText {
|
||||
DetectionResult result;
|
||||
std::string text;
|
||||
uint64_t start_timestamp_ms;
|
||||
uint64_t end_timestamp_ms;
|
||||
};
|
||||
|
||||
struct transcription_filter_data {
|
||||
obs_source_t *context; // obs filter source (this filter)
|
||||
size_t channels; // number of channels
|
||||
@@ -64,7 +52,7 @@ struct transcription_filter_data {
|
||||
struct circlebuf input_buffers[MAX_PREPROC_CHANNELS];
|
||||
|
||||
/* Resampler */
|
||||
audio_resampler_t *resampler;
|
||||
audio_resampler_t *resampler_to_whisper;
|
||||
|
||||
/* whisper */
|
||||
std::string whisper_model_path;
|
||||
@@ -90,15 +78,16 @@ struct transcription_filter_data {
|
||||
bool translate = false;
|
||||
std::string source_lang;
|
||||
std::string target_lang;
|
||||
std::string translation_output;
|
||||
bool buffered_output = false;
|
||||
bool enable_token_ts_dtw = false;
|
||||
std::string suppress_sentences;
|
||||
|
||||
// Last transcription result
|
||||
std::string last_text;
|
||||
|
||||
// Text source to output the subtitles
|
||||
obs_weak_source_t *text_source;
|
||||
char *text_source_name;
|
||||
std::mutex *text_source_mutex;
|
||||
std::string text_source_name;
|
||||
// Callback to set the text in the output text source (subtitles)
|
||||
std::function<void(const DetectionResultWithText &result)> setTextCallback;
|
||||
// Output file path to write the subtitles
|
||||
@@ -115,7 +104,7 @@ struct transcription_filter_data {
|
||||
// translation context
|
||||
struct translation_context translation_ctx;
|
||||
|
||||
CaptionMonitor captions_monitor;
|
||||
TokenBufferThread captions_monitor;
|
||||
|
||||
// ctor
|
||||
transcription_filter_data()
|
||||
@@ -125,11 +114,9 @@ struct transcription_filter_data {
|
||||
copy_buffers[i] = nullptr;
|
||||
}
|
||||
context = nullptr;
|
||||
resampler = nullptr;
|
||||
resampler_to_whisper = nullptr;
|
||||
whisper_model_path = "";
|
||||
whisper_context = nullptr;
|
||||
text_source = nullptr;
|
||||
text_source_mutex = nullptr;
|
||||
whisper_buf_mutex = nullptr;
|
||||
whisper_ctx_mutex = nullptr;
|
||||
wshiper_thread_cv = nullptr;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "plugin-support.h"
|
||||
#include "transcription-filter.h"
|
||||
#include "transcription-filter-data.h"
|
||||
#include "transcription-utils.h"
|
||||
#include "model-utils/model-downloader.h"
|
||||
#include "whisper-utils/whisper-processing.h"
|
||||
#include "whisper-utils/whisper-language.h"
|
||||
@@ -132,18 +133,8 @@ void transcription_filter_destroy(void *data)
|
||||
obs_log(gf->log_level, "filter destroy");
|
||||
shutdown_whisper_thread(gf);
|
||||
|
||||
if (gf->text_source_name) {
|
||||
bfree(gf->text_source_name);
|
||||
gf->text_source_name = nullptr;
|
||||
}
|
||||
|
||||
if (gf->text_source) {
|
||||
obs_weak_source_release(gf->text_source);
|
||||
gf->text_source = nullptr;
|
||||
}
|
||||
|
||||
if (gf->resampler) {
|
||||
audio_resampler_destroy(gf->resampler);
|
||||
if (gf->resampler_to_whisper) {
|
||||
audio_resampler_destroy(gf->resampler_to_whisper);
|
||||
}
|
||||
|
||||
{
|
||||
@@ -159,87 +150,14 @@ void transcription_filter_destroy(void *data)
|
||||
delete gf->whisper_buf_mutex;
|
||||
delete gf->whisper_ctx_mutex;
|
||||
delete gf->wshiper_thread_cv;
|
||||
delete gf->text_source_mutex;
|
||||
|
||||
delete gf;
|
||||
}
|
||||
|
||||
void acquire_weak_text_source_ref(struct transcription_filter_data *gf)
|
||||
void send_caption_to_source(const std::string &target_source_name, const std::string &str_copy,
|
||||
struct transcription_filter_data *gf)
|
||||
{
|
||||
if (!gf->text_source_name) {
|
||||
obs_log(gf->log_level, "text_source_name is null");
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
|
||||
|
||||
// acquire a weak ref to the new text source
|
||||
obs_source_t *source = obs_get_source_by_name(gf->text_source_name);
|
||||
if (source) {
|
||||
gf->text_source = obs_source_get_weak_source(source);
|
||||
obs_source_release(source);
|
||||
if (!gf->text_source) {
|
||||
obs_log(LOG_ERROR, "failed to get weak source for text source %s",
|
||||
gf->text_source_name);
|
||||
}
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "text source '%s' not found", gf->text_source_name);
|
||||
}
|
||||
}
|
||||
|
||||
#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
|
||||
#define is_trail_byte(c) (((c)&0xc0) == 0x80)
|
||||
|
||||
inline int lead_byte_length(const uint8_t c)
|
||||
{
|
||||
if ((c & 0xe0) == 0xc0) {
|
||||
return 2;
|
||||
} else if ((c & 0xf0) == 0xe0) {
|
||||
return 3;
|
||||
} else if ((c & 0xf8) == 0xf0) {
|
||||
return 4;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_valid_lead_byte(const uint8_t *c)
|
||||
{
|
||||
const int length = lead_byte_length(c[0]);
|
||||
if (length == 1) {
|
||||
return true;
|
||||
}
|
||||
if (length == 2 && is_trail_byte(c[1])) {
|
||||
return true;
|
||||
}
|
||||
if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) {
|
||||
return true;
|
||||
}
|
||||
if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void send_caption_to_source(const std::string &str_copy, struct transcription_filter_data *gf)
|
||||
{
|
||||
if (!gf->text_source_mutex) {
|
||||
obs_log(LOG_ERROR, "text_source_mutex is null");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!gf->text_source) {
|
||||
// attempt to acquire a weak ref to the text source if it's yet available
|
||||
acquire_weak_text_source_ref(gf);
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
|
||||
|
||||
if (!gf->text_source) {
|
||||
obs_log(gf->log_level, "text_source is null");
|
||||
return;
|
||||
}
|
||||
auto target = obs_weak_source_get_source(gf->text_source);
|
||||
auto target = obs_get_source_by_name(target_source_name.c_str());
|
||||
if (!target) {
|
||||
obs_log(gf->log_level, "text_source target is null");
|
||||
return;
|
||||
@@ -267,52 +185,9 @@ void set_text_callback(struct transcription_filter_data *gf,
|
||||
}
|
||||
gf->last_sub_render_time = now;
|
||||
|
||||
#ifdef _WIN32
|
||||
// Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs
|
||||
// 0xf?, and 0xc? becomes 0xe?, so we need to fix it.
|
||||
std::stringstream ss;
|
||||
uint8_t *c_str = (uint8_t *)result.text.c_str();
|
||||
for (size_t i = 0; i < result.text.size(); ++i) {
|
||||
if (is_lead_byte(c_str[i])) {
|
||||
// this is a unicode leading byte
|
||||
// if the next char is 0xff - it's a bug char, replace it with 0x9f
|
||||
if (c_str[i + 1] == 0xff) {
|
||||
c_str[i + 1] = 0x9f;
|
||||
}
|
||||
if (!is_valid_lead_byte(c_str + i)) {
|
||||
// This is a bug lead byte, because it's length 3 and the i+2 byte is also
|
||||
// a lead byte
|
||||
c_str[i] = c_str[i] - 0x20;
|
||||
}
|
||||
} else {
|
||||
if (c_str[i] >= 0xf8) {
|
||||
// this may be a malformed lead byte.
|
||||
// lets see if it becomes a valid lead byte if we "fix" it
|
||||
uint8_t buf_[4];
|
||||
buf_[0] = c_str[i] - 0x20;
|
||||
buf_[1] = c_str[i + 1];
|
||||
buf_[2] = c_str[i + 2];
|
||||
buf_[3] = c_str[i + 3];
|
||||
if (is_valid_lead_byte(buf_)) {
|
||||
// this is a malformed lead byte, fix it
|
||||
c_str[i] = c_str[i] - 0x20;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string str_copy = (char *)c_str;
|
||||
#else
|
||||
std::string str_copy = result.text;
|
||||
#endif
|
||||
|
||||
// remove trailing spaces, newlines, tabs or punctuation
|
||||
str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(),
|
||||
[](unsigned char ch) {
|
||||
return !std::isspace(ch) || !std::ispunct(ch);
|
||||
})
|
||||
.base(),
|
||||
str_copy.end());
|
||||
// recondition the text
|
||||
std::string str_copy = fix_utf8(result.text);
|
||||
str_copy = remove_leading_trailing_nonalpha(str_copy);
|
||||
|
||||
if (gf->translate) {
|
||||
obs_log(gf->log_level, "Translating text. %s -> %s", gf->source_lang.c_str(),
|
||||
@@ -324,7 +199,13 @@ void set_text_callback(struct transcription_filter_data *gf,
|
||||
obs_log(LOG_INFO, "Translation: '%s' -> '%s'", str_copy.c_str(),
|
||||
translated_text.c_str());
|
||||
}
|
||||
str_copy = translated_text;
|
||||
if (gf->translation_output == "none") {
|
||||
// overwrite the original text with the translated text
|
||||
str_copy = translated_text;
|
||||
} else {
|
||||
// send the translation to the selected source
|
||||
send_caption_to_source(gf->translation_output, translated_text, gf);
|
||||
}
|
||||
} else {
|
||||
obs_log(gf->log_level, "Failed to translate text");
|
||||
}
|
||||
@@ -333,7 +214,7 @@ void set_text_callback(struct transcription_filter_data *gf,
|
||||
gf->last_text = str_copy;
|
||||
|
||||
if (gf->buffered_output) {
|
||||
gf->captions_monitor.addWords(split_words(str_copy));
|
||||
gf->captions_monitor.addWords(result.tokens);
|
||||
}
|
||||
|
||||
if (gf->caption_to_stream) {
|
||||
@@ -344,7 +225,7 @@ void set_text_callback(struct transcription_filter_data *gf,
|
||||
}
|
||||
}
|
||||
|
||||
if (gf->output_file_path != "" && !gf->text_source_name) {
|
||||
if (gf->output_file_path != "" && gf->text_source_name.empty()) {
|
||||
// Check if we should save the sentence
|
||||
if (gf->save_only_while_recording && !obs_frontend_recording_active()) {
|
||||
// We are not recording, do not save the sentence to file
|
||||
@@ -396,7 +277,7 @@ void set_text_callback(struct transcription_filter_data *gf,
|
||||
} else {
|
||||
if (!gf->buffered_output) {
|
||||
// Send the caption to the text source
|
||||
send_caption_to_source(str_copy, gf);
|
||||
send_caption_to_source(gf->text_source_name, str_copy, gf);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -427,12 +308,21 @@ void transcription_filter_update(void *data, obs_data_t *s)
|
||||
gf->process_while_muted = obs_data_get_bool(s, "process_while_muted");
|
||||
gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration");
|
||||
gf->last_sub_render_time = 0;
|
||||
gf->buffered_output = obs_data_get_bool(s, "buffered_output");
|
||||
bool new_buffered_output = obs_data_get_bool(s, "buffered_output");
|
||||
if (new_buffered_output != gf->buffered_output) {
|
||||
gf->buffered_output = new_buffered_output;
|
||||
gf->overlap_ms = gf->buffered_output ? MAX_OVERLAP_SIZE_MSEC
|
||||
: DEFAULT_OVERLAP_SIZE_MSEC;
|
||||
gf->overlap_frames =
|
||||
(size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
|
||||
}
|
||||
|
||||
bool new_translate = obs_data_get_bool(s, "translate");
|
||||
gf->source_lang = obs_data_get_string(s, "translate_source_language");
|
||||
gf->target_lang = obs_data_get_string(s, "translate_target_language");
|
||||
gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context");
|
||||
gf->translation_output = obs_data_get_string(s, "translate_output");
|
||||
gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences");
|
||||
|
||||
if (new_translate != gf->translate) {
|
||||
if (new_translate) {
|
||||
@@ -451,19 +341,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
|
||||
strcmp(new_text_source_name, "(null)") == 0 ||
|
||||
strcmp(new_text_source_name, "text_file") == 0 || strlen(new_text_source_name) == 0) {
|
||||
// new selected text source is not valid, release the old one
|
||||
if (gf->text_source) {
|
||||
if (!gf->text_source_mutex) {
|
||||
obs_log(LOG_ERROR, "text_source_mutex is null");
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
|
||||
old_weak_text_source = gf->text_source;
|
||||
gf->text_source = nullptr;
|
||||
}
|
||||
if (gf->text_source_name) {
|
||||
bfree(gf->text_source_name);
|
||||
gf->text_source_name = nullptr;
|
||||
}
|
||||
gf->text_source_name.clear();
|
||||
gf->output_file_path = "";
|
||||
if (strcmp(new_text_source_name, "text_file") == 0) {
|
||||
// set the output file path
|
||||
@@ -475,24 +353,9 @@ void transcription_filter_update(void *data, obs_data_t *s)
|
||||
}
|
||||
} else {
|
||||
// new selected text source is valid, check if it's different from the old one
|
||||
if (gf->text_source_name == nullptr ||
|
||||
strcmp(new_text_source_name, gf->text_source_name) != 0) {
|
||||
if (gf->text_source_name != new_text_source_name) {
|
||||
// new text source is different from the old one, release the old one
|
||||
if (gf->text_source) {
|
||||
if (!gf->text_source_mutex) {
|
||||
obs_log(LOG_ERROR, "text_source_mutex is null");
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
|
||||
old_weak_text_source = gf->text_source;
|
||||
gf->text_source = nullptr;
|
||||
}
|
||||
if (gf->text_source_name) {
|
||||
// free the old text source name
|
||||
bfree(gf->text_source_name);
|
||||
gf->text_source_name = nullptr;
|
||||
}
|
||||
gf->text_source_name = bstrdup(new_text_source_name);
|
||||
gf->text_source_name = new_text_source_name;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -507,7 +370,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
|
||||
}
|
||||
|
||||
obs_log(gf->log_level, "update whisper model");
|
||||
update_whsiper_model_path(gf, s);
|
||||
update_whsiper_model(gf, s);
|
||||
|
||||
obs_log(gf->log_level, "update whisper params");
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
@@ -597,15 +460,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
|
||||
dst.format = AUDIO_FORMAT_FLOAT_PLANAR;
|
||||
dst.speakers = convert_speaker_layout((uint8_t)1);
|
||||
|
||||
gf->resampler = audio_resampler_create(&dst, &src);
|
||||
gf->resampler_to_whisper = audio_resampler_create(&dst, &src);
|
||||
|
||||
obs_log(gf->log_level, "setup mutexes and condition variables");
|
||||
gf->whisper_buf_mutex = new std::mutex();
|
||||
gf->whisper_ctx_mutex = new std::mutex();
|
||||
gf->wshiper_thread_cv = new std::condition_variable();
|
||||
gf->text_source_mutex = new std::mutex();
|
||||
obs_log(gf->log_level, "clear text source data");
|
||||
gf->text_source = nullptr;
|
||||
const char *subtitle_sources = obs_data_get_string(settings, "subtitle_sources");
|
||||
if (subtitle_sources == nullptr || strcmp(subtitle_sources, "none") == 0 ||
|
||||
strcmp(subtitle_sources, "(null)") == 0 || strlen(subtitle_sources) == 0) {
|
||||
@@ -619,8 +480,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
|
||||
// create a new OBS text source called "LocalVocal Subtitles"
|
||||
obs_source_t *scene_as_source = obs_frontend_get_current_scene();
|
||||
obs_scene_t *scene = obs_scene_from_source(scene_as_source);
|
||||
#ifdef _WIN32
|
||||
source = obs_source_create("text_gdiplus_v2", "LocalVocal Subtitles",
|
||||
nullptr, nullptr);
|
||||
#else
|
||||
source = obs_source_create("text_ft2_source_v2", "LocalVocal Subtitles",
|
||||
nullptr, nullptr);
|
||||
#endif
|
||||
if (source) {
|
||||
// add source to the current scene
|
||||
obs_scene_add(scene, source);
|
||||
@@ -660,11 +526,11 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
|
||||
}
|
||||
obs_source_release(scene_as_source);
|
||||
}
|
||||
gf->text_source_name = bstrdup("LocalVocal Subtitles");
|
||||
gf->text_source_name = "LocalVocal Subtitles";
|
||||
obs_data_set_string(settings, "subtitle_sources", "LocalVocal Subtitles");
|
||||
} else {
|
||||
// set the text source name
|
||||
gf->text_source_name = bstrdup(subtitle_sources);
|
||||
gf->text_source_name = subtitle_sources;
|
||||
}
|
||||
obs_log(gf->log_level, "clear paths and whisper context");
|
||||
gf->whisper_model_file_currently_loaded = "";
|
||||
@@ -673,13 +539,14 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
|
||||
gf->whisper_context = nullptr;
|
||||
|
||||
gf->captions_monitor.initialize(
|
||||
gf,
|
||||
[gf](const std::string &text) {
|
||||
obs_log(LOG_INFO, "Captions: %s", text.c_str());
|
||||
if (gf->buffered_output) {
|
||||
send_caption_to_source(text, gf);
|
||||
send_caption_to_source(gf->text_source_name, text, gf);
|
||||
}
|
||||
},
|
||||
20, std::chrono::seconds(10));
|
||||
30, std::chrono::seconds(10));
|
||||
|
||||
obs_log(gf->log_level, "run update");
|
||||
// get the settings updated on the filter data struct
|
||||
@@ -791,6 +658,7 @@ void transcription_filter_defaults(obs_data_t *s)
|
||||
obs_data_set_default_string(s, "translate_target_language", "__es__");
|
||||
obs_data_set_default_string(s, "translate_source_language", "__en__");
|
||||
obs_data_set_default_bool(s, "translate_add_context", true);
|
||||
obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT);
|
||||
|
||||
// Whisper parameters
|
||||
obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
@@ -805,6 +673,7 @@ void transcription_filter_defaults(obs_data_t *s)
|
||||
obs_data_set_default_bool(s, "print_realtime", false);
|
||||
obs_data_set_default_bool(s, "print_timestamps", false);
|
||||
obs_data_set_default_bool(s, "token_timestamps", false);
|
||||
obs_data_set_default_bool(s, "dtw_token_timestamps", false);
|
||||
obs_data_set_default_double(s, "thold_pt", 0.01);
|
||||
obs_data_set_default_double(s, "thold_ptsum", 0.01);
|
||||
obs_data_set_default_int(s, "max_len", 0);
|
||||
@@ -919,6 +788,15 @@ obs_properties_t *transcription_filter_properties(void *data)
|
||||
obs_property_list_add_string(prop_src, language.second.c_str(),
|
||||
language.first.c_str());
|
||||
}
|
||||
// add option for routing the translation to an output source
|
||||
obs_property_t *prop_output = obs_properties_add_list(translation_group, "translate_output",
|
||||
MT_("translate_output"),
|
||||
OBS_COMBO_TYPE_LIST,
|
||||
OBS_COMBO_FORMAT_STRING);
|
||||
obs_property_list_add_string(prop_output, "Write to captions output", "none");
|
||||
// TODO add file output option
|
||||
// obs_property_list_add_string(...
|
||||
obs_enum_sources(add_sources_to_list, prop_output);
|
||||
|
||||
// add callback to enable/disable translation group
|
||||
obs_property_set_modified_callback(translation_group_prop, [](obs_properties_t *props,
|
||||
@@ -928,7 +806,7 @@ obs_properties_t *transcription_filter_properties(void *data)
|
||||
// Show/Hide the translation group
|
||||
const bool translate_enabled = obs_data_get_bool(settings, "translate");
|
||||
for (const auto &prop : {"translate_target_language", "translate_source_language",
|
||||
"translate_add_context"}) {
|
||||
"translate_add_context", "translate_output"}) {
|
||||
obs_property_set_visible(obs_properties_get(props, prop),
|
||||
translate_enabled);
|
||||
}
|
||||
@@ -946,21 +824,38 @@ obs_properties_t *transcription_filter_properties(void *data)
|
||||
for (const std::string &prop_name :
|
||||
{"whisper_params_group", "log_words", "caption_to_stream", "buffer_size_msec",
|
||||
"overlap_size_msec", "step_by_step_processing", "min_sub_duration",
|
||||
"process_while_muted", "buffered_output", "vad_enabled", "log_level"}) {
|
||||
"process_while_muted", "buffered_output", "vad_enabled", "log_level",
|
||||
"suppress_sentences"}) {
|
||||
obs_property_set_visible(obs_properties_get(props, prop_name.c_str()),
|
||||
show_hide);
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
obs_properties_add_bool(ppts, "buffered_output", MT_("buffered_output"));
|
||||
obs_property_t *buffered_output_prop =
|
||||
obs_properties_add_bool(ppts, "buffered_output", MT_("buffered_output"));
|
||||
// add on-change handler for buffered_output
|
||||
obs_property_set_modified_callback(buffered_output_prop, [](obs_properties_t *props,
|
||||
obs_property_t *property,
|
||||
obs_data_t *settings) {
|
||||
UNUSED_PARAMETER(property);
|
||||
UNUSED_PARAMETER(props);
|
||||
// if buffered output is enabled set the overlap to max else set it to default
|
||||
obs_data_set_int(settings, "overlap_size_msec",
|
||||
obs_data_get_bool(settings, "buffered_output")
|
||||
? MAX_OVERLAP_SIZE_MSEC
|
||||
: DEFAULT_OVERLAP_SIZE_MSEC);
|
||||
return true;
|
||||
});
|
||||
|
||||
obs_properties_add_bool(ppts, "log_words", MT_("log_words"));
|
||||
obs_properties_add_bool(ppts, "caption_to_stream", MT_("caption_to_stream"));
|
||||
|
||||
obs_properties_add_int_slider(ppts, "buffer_size_msec", MT_("buffer_size_msec"), 1000,
|
||||
DEFAULT_BUFFER_SIZE_MSEC, 250);
|
||||
obs_properties_add_int_slider(ppts, "overlap_size_msec", MT_("overlap_size_msec"), 50, 300,
|
||||
50);
|
||||
obs_properties_add_int_slider(ppts, "overlap_size_msec", MT_("overlap_size_msec"),
|
||||
MIN_OVERLAP_SIZE_MSEC, MAX_OVERLAP_SIZE_MSEC,
|
||||
(MAX_OVERLAP_SIZE_MSEC - MIN_OVERLAP_SIZE_MSEC) / 5);
|
||||
|
||||
obs_property_t *step_by_step_processing = obs_properties_add_bool(
|
||||
ppts, "step_by_step_processing", MT_("step_by_step_processing"));
|
||||
@@ -985,10 +880,14 @@ obs_properties_t *transcription_filter_properties(void *data)
|
||||
|
||||
obs_property_t *list = obs_properties_add_list(ppts, "log_level", MT_("log_level"),
|
||||
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
|
||||
obs_property_list_add_int(list, "DEBUG", LOG_DEBUG);
|
||||
obs_property_list_add_int(list, "DEBUG (Won't show)", LOG_DEBUG);
|
||||
obs_property_list_add_int(list, "INFO", LOG_INFO);
|
||||
obs_property_list_add_int(list, "WARNING", LOG_WARNING);
|
||||
|
||||
// add a text input for sentences to suppress
|
||||
obs_properties_add_text(ppts, "suppress_sentences", MT_("suppress_sentences"),
|
||||
OBS_TEXT_MULTILINE);
|
||||
|
||||
obs_properties_t *whisper_params_group = obs_properties_create();
|
||||
obs_properties_add_group(ppts, "whisper_params_group", MT_("whisper_parameters"),
|
||||
OBS_GROUP_NORMAL, whisper_params_group);
|
||||
@@ -1043,6 +942,9 @@ obs_properties_t *transcription_filter_properties(void *data)
|
||||
obs_properties_add_bool(whisper_params_group, "print_timestamps", MT_("print_timestamps"));
|
||||
// bool token_timestamps; // enable token-level timestamps
|
||||
obs_properties_add_bool(whisper_params_group, "token_timestamps", MT_("token_timestamps"));
|
||||
// enable DTW timestamps
|
||||
obs_properties_add_bool(whisper_params_group, "dtw_token_timestamps",
|
||||
MT_("dtw_token_timestamps"));
|
||||
// float thold_pt; // timestamp token probability threshold (~0.01)
|
||||
obs_properties_add_float_slider(whisper_params_group, "thold_pt", MT_("thold_pt"), 0.0f,
|
||||
1.0f, 0.05f);
|
||||
|
||||
@@ -19,6 +19,8 @@ const char *const PLUGIN_INFO_TEMPLATE =
|
||||
"<a href=\"https://github.com/occ-ai\">OCC AI</a> ❤️ "
|
||||
"<a href=\"https://www.patreon.com/RoyShilkrot\">Support & Follow</a>";
|
||||
|
||||
const char *const SUPPRESS_SENTENCES_DEFAULT = "Thank you for watching\nThank you";
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
104
src/transcription-utils.cpp
Normal file
104
src/transcription-utils.cpp
Normal file
@@ -0,0 +1,104 @@
|
||||
#include "transcription-utils.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
|
||||
#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
|
||||
#define is_trail_byte(c) (((c)&0xc0) == 0x80)
|
||||
|
||||
inline int lead_byte_length(const uint8_t c)
|
||||
{
|
||||
if ((c & 0xe0) == 0xc0) {
|
||||
return 2;
|
||||
} else if ((c & 0xf0) == 0xe0) {
|
||||
return 3;
|
||||
} else if ((c & 0xf8) == 0xf0) {
|
||||
return 4;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_valid_lead_byte(const uint8_t *c)
|
||||
{
|
||||
const int length = lead_byte_length(c[0]);
|
||||
if (length == 1) {
|
||||
return true;
|
||||
}
|
||||
if (length == 2 && is_trail_byte(c[1])) {
|
||||
return true;
|
||||
}
|
||||
if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) {
|
||||
return true;
|
||||
}
|
||||
if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
 * Repair UTF-8 strings mangled by a Windows charset output bug, where lead
 * bytes 0xc? / 0xd? are shifted up to 0xe? / 0xf? and a trail byte may appear
 * as 0xff. On non-Windows platforms the input is returned unchanged.
 * @param str: the possibly-mangled UTF-8 string.
 * @return: the repaired string.
 */
std::string fix_utf8(const std::string &str)
{
#ifdef _WIN32
	// Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs
	// 0xf?, and 0xc? becomes 0xe?, so we need to fix it.
	// Work on a copy: the original cast away const and mutated the buffer behind
	// str.c_str(), which is undefined behavior on a const reference.
	std::string fixed = str;
	const size_t len = fixed.size();
	for (size_t i = 0; i < len; ++i) {
		const uint8_t c = (uint8_t)fixed[i];
		// Gather up to 4 bytes starting at i, zero-padded at the end of the
		// string, so the validity check never reads out of bounds (the
		// original read c_str[i+1..i+3] unconditionally).
		uint8_t buf[4] = {c, 0, 0, 0};
		for (size_t k = 1; k < 4 && i + k < len; ++k) {
			buf[k] = (uint8_t)fixed[i + k];
		}
		if (is_lead_byte(c)) {
			// if the next char is 0xff - it's a bug char, replace it with 0x9f
			if (i + 1 < len && (uint8_t)fixed[i + 1] == 0xff) {
				fixed[i + 1] = (char)0x9f;
				buf[1] = 0x9f;
			}
			if (!is_valid_lead_byte(buf)) {
				// This is a bug lead byte (e.g. declared length 3 but a
				// following byte is also a lead byte) - shift it back down.
				fixed[i] = (char)(c - 0x20);
			}
		} else if (c >= 0xf8) {
			// this may be a malformed lead byte.
			// lets see if it becomes a valid lead byte if we "fix" it
			buf[0] = (uint8_t)(c - 0x20);
			if (is_valid_lead_byte(buf)) {
				// this is a malformed lead byte, fix it
				fixed[i] = (char)(c - 0x20);
			}
		}
	}
	return fixed;
#else
	return str;
#endif
}
|
||||
|
||||
/*
 * Remove leading and trailing non-alphabetic characters from a string.
 * This function is used to remove leading and trailing spaces, newlines, tabs or punctuation.
 * @param str: the string to remove leading and trailing non-alphabetic characters from.
 * @return: the string with leading and trailing non-alphabetic characters removed.
 */
std::string remove_leading_trailing_nonalpha(const std::string &str)
{
	// A character is kept when it is neither whitespace nor punctuation.
	// NOTE: the previous predicate used `||`, which is true for every char
	// (nothing is both space and punctuation), so nothing was ever stripped.
	const auto is_content = [](unsigned char ch) {
		return !std::isspace(ch) && !std::ispunct(ch);
	};
	std::string str_copy = str;
	// remove trailing spaces, newlines, tabs or punctuation
	str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(), is_content).base(),
		       str_copy.end());
	// remove leading spaces, newlines, tabs or punctuation
	str_copy.erase(str_copy.begin(),
		       std::find_if(str_copy.begin(), str_copy.end(), is_content));
	return str_copy;
}
|
||||
9
src/transcription-utils.h
Normal file
9
src/transcription-utils.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#ifndef TRANSCRIPTION_UTILS_H
#define TRANSCRIPTION_UTILS_H

#include <string>

// Repair UTF-8 strings mangled by a Windows charset output bug
// (returns the input unchanged on non-Windows platforms).
std::string fix_utf8(const std::string &str);
// Strip leading and trailing whitespace/punctuation characters.
std::string remove_leading_trailing_nonalpha(const std::string &str);

#endif // TRANSCRIPTION_UTILS_H
|
||||
@@ -10,7 +10,10 @@
|
||||
#include <cstdio>
|
||||
#include <cstdarg>
|
||||
|
||||
//#define __DEBUG_SPEECH_PROB___
|
||||
#include <obs.h>
|
||||
#include "plugin-support.h"
|
||||
|
||||
// #define __DEBUG_SPEECH_PROB___
|
||||
|
||||
timestamp_t::timestamp_t(int start_, int end_) : start(start_), end(end_){};
|
||||
|
||||
@@ -144,8 +147,8 @@ void VadIterator::predict(const std::vector<float> &data)
|
||||
float speech =
|
||||
current_sample -
|
||||
window_size_samples; // minus window_size_samples to get precise start time point.
|
||||
printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob,
|
||||
current_sample - window_size_samples);
|
||||
obs_log(LOG_INFO, "{ start: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate,
|
||||
speech_prob, current_sample - window_size_samples);
|
||||
#endif //__DEBUG_SPEECH_PROB___
|
||||
if (temp_end != 0) {
|
||||
temp_end = 0;
|
||||
@@ -194,16 +197,18 @@ void VadIterator::predict(const std::vector<float> &data)
|
||||
float speech =
|
||||
current_sample -
|
||||
window_size_samples; // minus window_size_samples to get precise start time point.
|
||||
printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate,
|
||||
speech_prob, current_sample - window_size_samples);
|
||||
obs_log(LOG_INFO, "{ speaking: %.3f s (%.3f) %08d}",
|
||||
1.0 * speech / sample_rate, speech_prob,
|
||||
current_sample - window_size_samples);
|
||||
#endif //__DEBUG_SPEECH_PROB___
|
||||
} else {
|
||||
#ifdef __DEBUG_SPEECH_PROB___
|
||||
float speech =
|
||||
current_sample -
|
||||
window_size_samples; // minus window_size_samples to get precise start time point.
|
||||
printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate,
|
||||
speech_prob, current_sample - window_size_samples);
|
||||
obs_log(LOG_INFO, "{ silence: %.3f s (%.3f) %08d}",
|
||||
1.0 * speech / sample_rate, speech_prob,
|
||||
current_sample - window_size_samples);
|
||||
#endif //__DEBUG_SPEECH_PROB___
|
||||
}
|
||||
return;
|
||||
@@ -215,8 +220,8 @@ void VadIterator::predict(const std::vector<float> &data)
|
||||
float speech =
|
||||
current_sample - window_size_samples -
|
||||
speech_pad_samples; // minus window_size_samples to get precise start time point.
|
||||
printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob,
|
||||
current_sample - window_size_samples);
|
||||
obs_log(LOG_INFO, "{ end: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate,
|
||||
speech_prob, current_sample - window_size_samples);
|
||||
#endif //__DEBUG_SPEECH_PROB___
|
||||
if (triggered == true) {
|
||||
if (temp_end == 0) {
|
||||
@@ -285,7 +290,7 @@ void VadIterator::collect_chunks(const std::vector<float> &input_wav,
|
||||
output_wav.clear();
|
||||
for (size_t i = 0; i < speeches.size(); i++) {
|
||||
#ifdef __DEBUG_SPEECH_PROB___
|
||||
std::cout << speeches[i].c_str() << std::endl;
|
||||
obs_log(LOG_INFO, "%s", speeches[i].string().c_str());
|
||||
#endif //#ifdef __DEBUG_SPEECH_PROB___
|
||||
std::vector<float> slice(&input_wav[speeches[i].start],
|
||||
&input_wav[speeches[i].end]);
|
||||
|
||||
131
src/whisper-utils/token-buffer-thread.cpp
Normal file
131
src/whisper-utils/token-buffer-thread.cpp
Normal file
@@ -0,0 +1,131 @@
|
||||
#include "token-buffer-thread.h"
|
||||
#include "./whisper-utils.h"
|
||||
|
||||
// Signal the worker to stop, wake it, and join it before destruction.
TokenBufferThread::~TokenBufferThread()
{
	{
		std::lock_guard<std::mutex> lock(queueMutex);
		stop = true;
	}
	condVar.notify_all();
	// Join only if initialize() actually started the worker: calling join()
	// on a non-joinable (default-constructed) thread calls std::terminate().
	if (workerThread.joinable()) {
		workerThread.join();
	}
}
|
||||
|
||||
void TokenBufferThread::initialize(struct transcription_filter_data *gf_,
|
||||
std::function<void(const std::string &)> callback_,
|
||||
size_t maxSize_, std::chrono::seconds maxTime_)
|
||||
{
|
||||
this->gf = gf_;
|
||||
this->callback = callback_;
|
||||
this->maxSize = maxSize_;
|
||||
this->maxTime = maxTime_;
|
||||
this->initialized = true;
|
||||
this->workerThread = std::thread(&TokenBufferThread::monitor, this);
|
||||
}
|
||||
|
||||
void TokenBufferThread::log_token_vector(const std::vector<whisper_token_data> &tokens)
|
||||
{
|
||||
std::string output;
|
||||
for (const auto &token : tokens) {
|
||||
const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
|
||||
output += token_str;
|
||||
}
|
||||
obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str());
|
||||
}
|
||||
|
||||
void TokenBufferThread::addWords(const std::vector<whisper_token_data> &words)
|
||||
{
|
||||
obs_log(LOG_INFO, "TokenBufferThread::addWords");
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(queueMutex);
|
||||
|
||||
// convert current wordQueue to vector
|
||||
std::vector<whisper_token_data> currentWords(wordQueue.begin(), wordQueue.end());
|
||||
|
||||
log_token_vector(currentWords);
|
||||
log_token_vector(words);
|
||||
|
||||
// run reconstructSentence
|
||||
std::vector<whisper_token_data> reconstructed =
|
||||
reconstructSentence(currentWords, words);
|
||||
|
||||
log_token_vector(reconstructed);
|
||||
|
||||
// clear the wordQueue
|
||||
wordQueue.clear();
|
||||
|
||||
// add the reconstructed sentence to the wordQueue
|
||||
for (const auto &word : reconstructed) {
|
||||
wordQueue.push_back(word);
|
||||
}
|
||||
|
||||
newDataAvailable = true;
|
||||
}
|
||||
condVar.notify_all();
|
||||
}
|
||||
|
||||
void TokenBufferThread::monitor()
|
||||
{
|
||||
obs_log(LOG_INFO, "TokenBufferThread::monitor");
|
||||
auto startTime = std::chrono::steady_clock::now();
|
||||
while (this->initialized && !this->stop) {
|
||||
std::unique_lock<std::mutex> lock(this->queueMutex);
|
||||
// wait for new data or stop signal
|
||||
this->condVar.wait(lock, [this] { return this->newDataAvailable || this->stop; });
|
||||
|
||||
if (this->stop) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (this->wordQueue.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (this->gf->whisper_context == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// emit up to maxSize words from the wordQueue
|
||||
std::vector<whisper_token_data> emitted;
|
||||
while (!this->wordQueue.empty() && emitted.size() <= this->maxSize) {
|
||||
emitted.push_back(this->wordQueue.front());
|
||||
this->wordQueue.pop_front();
|
||||
}
|
||||
obs_log(LOG_INFO, "TokenBufferThread::monitor: emitting %d words", emitted.size());
|
||||
log_token_vector(emitted);
|
||||
// emit the caption from the tokens
|
||||
std::string output;
|
||||
for (const auto &token : emitted) {
|
||||
const char *token_str =
|
||||
whisper_token_to_str(this->gf->whisper_context, token.id);
|
||||
output += token_str;
|
||||
}
|
||||
this->callback(output);
|
||||
// push back the words that were emitted, in reverse order
|
||||
for (auto it = emitted.rbegin(); it != emitted.rend(); ++it) {
|
||||
this->wordQueue.push_front(*it);
|
||||
}
|
||||
|
||||
// check if we need to flush the queue
|
||||
auto elapsedTime = std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::steady_clock::now() - startTime);
|
||||
if (this->wordQueue.size() >= this->maxSize || elapsedTime >= this->maxTime) {
|
||||
// flush the queue if it's full or we've reached the max time
|
||||
size_t words_to_flush = std::min(this->wordQueue.size(), this->maxSize);
|
||||
// make sure we leave at least 3 words in the queue
|
||||
size_t words_remaining = this->wordQueue.size() - words_to_flush;
|
||||
if (words_remaining < 3) {
|
||||
words_to_flush -= 3 - words_remaining;
|
||||
}
|
||||
obs_log(LOG_INFO, "TokenBufferThread::monitor: flushing %d words",
|
||||
words_to_flush);
|
||||
for (size_t i = 0; i < words_to_flush; ++i) {
|
||||
wordQueue.pop_front();
|
||||
}
|
||||
startTime = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
newDataAvailable = false;
|
||||
}
|
||||
obs_log(LOG_INFO, "TokenBufferThread::monitor: done");
|
||||
}
|
||||
49
src/whisper-utils/token-buffer-thread.h
Normal file
49
src/whisper-utils/token-buffer-thread.h
Normal file
@@ -0,0 +1,49 @@
|
||||
#ifndef TOKEN_BUFFER_THREAD_H
#define TOKEN_BUFFER_THREAD_H

#include <queue>
#include <deque>
#include <vector>
#include <chrono>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <string>

#include <obs.h>

#include <whisper.h>

#include "plugin-support.h"

struct transcription_filter_data;

// Buffers whisper tokens across overlapping inference passes, merges them,
// and emits caption strings to a callback from a dedicated worker thread.
class TokenBufferThread {
public:
	// default constructor
	TokenBufferThread() = default;

	~TokenBufferThread();
	void initialize(struct transcription_filter_data *gf,
			std::function<void(const std::string &)> callback_, size_t maxSize_,
			std::chrono::seconds maxTime_);

	// merge `words` into the pending queue (thread-safe) and wake the worker
	void addWords(const std::vector<whisper_token_data> &words);

private:
	// worker loop: emits/flushes queued tokens until `stop` is set
	void monitor();
	// debug helper: logs the concatenated text of a token sequence
	void log_token_vector(const std::vector<whisper_token_data> &tokens);
	struct transcription_filter_data *gf = nullptr;
	std::deque<whisper_token_data> wordQueue;
	std::thread workerThread;
	std::mutex queueMutex;
	std::condition_variable condVar;
	// matches the signature initialize() accepts (was std::function<void(std::string)>)
	std::function<void(const std::string &)> callback;
	size_t maxSize = 0;
	std::chrono::seconds maxTime{};
	// was uninitialized, leaving the worker's exit condition indeterminate
	bool stop = false;
	bool initialized = false;
	bool newDataAvailable = false;
};

#endif
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "plugin-support.h"
|
||||
#include "transcription-filter-data.h"
|
||||
#include "whisper-processing.h"
|
||||
#include "whisper-utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
@@ -109,7 +110,8 @@ bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float v
|
||||
return true;
|
||||
}
|
||||
|
||||
struct whisper_context *init_whisper_context(const std::string &model_path_in)
|
||||
struct whisper_context *init_whisper_context(const std::string &model_path_in,
|
||||
struct transcription_filter_data *gf)
|
||||
{
|
||||
std::string model_path = model_path_in;
|
||||
|
||||
@@ -131,14 +133,15 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in)
|
||||
whisper_log_set(
|
||||
[](enum ggml_log_level level, const char *text, void *user_data) {
|
||||
UNUSED_PARAMETER(level);
|
||||
UNUSED_PARAMETER(user_data);
|
||||
struct transcription_filter_data *ctx =
|
||||
static_cast<struct transcription_filter_data *>(user_data);
|
||||
// remove trailing newline
|
||||
char *text_copy = bstrdup(text);
|
||||
text_copy[strcspn(text_copy, "\n")] = 0;
|
||||
obs_log(LOG_INFO, "Whisper: %s", text_copy);
|
||||
obs_log(ctx->log_level, "Whisper: %s", text_copy);
|
||||
bfree(text_copy);
|
||||
},
|
||||
nullptr);
|
||||
gf);
|
||||
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
#ifdef LOCALVOCAL_WITH_CUDA
|
||||
@@ -152,6 +155,16 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in)
|
||||
obs_log(LOG_INFO, "Using CPU for inference");
|
||||
#endif
|
||||
|
||||
cparams.dtw_token_timestamps = gf->enable_token_ts_dtw;
|
||||
if (gf->enable_token_ts_dtw) {
|
||||
obs_log(LOG_INFO, "DTW token timestamps enabled");
|
||||
cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
|
||||
// cparams.dtw_n_top = 4;
|
||||
} else {
|
||||
obs_log(LOG_INFO, "DTW token timestamps disabled");
|
||||
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
|
||||
}
|
||||
|
||||
struct whisper_context *ctx = nullptr;
|
||||
try {
|
||||
#ifdef _WIN32
|
||||
@@ -196,16 +209,19 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in)
|
||||
}
|
||||
|
||||
struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
|
||||
const float *pcm32f_data, size_t pcm32f_size)
|
||||
const float *pcm32f_data, size_t pcm32f_size,
|
||||
bool zero_start)
|
||||
{
|
||||
UNUSED_PARAMETER(zero_start);
|
||||
|
||||
if (gf == nullptr) {
|
||||
obs_log(LOG_ERROR, "run_whisper_inference: gf is null");
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
|
||||
}
|
||||
|
||||
if (pcm32f_data == nullptr || pcm32f_size == 0) {
|
||||
obs_log(LOG_ERROR, "run_whisper_inference: pcm32f_data is null or size is 0");
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
|
||||
}
|
||||
|
||||
obs_log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__,
|
||||
@@ -215,7 +231,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
if (gf->whisper_context == nullptr) {
|
||||
obs_log(LOG_WARNING, "whisper context is null");
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
|
||||
}
|
||||
|
||||
// Get the duration in ms since the beginning of the stream (gf->start_timestamp_ms)
|
||||
@@ -234,47 +250,92 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
|
||||
obs_log(LOG_ERROR, "Whisper exception: %s. Filter restart is required", e.what());
|
||||
whisper_free(gf->whisper_context);
|
||||
gf->whisper_context = nullptr;
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
|
||||
}
|
||||
|
||||
if (whisper_full_result != 0) {
|
||||
obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result);
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
|
||||
} else {
|
||||
// duration in ms
|
||||
const uint64_t duration_ms = (uint64_t)(pcm32f_size * 1000 / WHISPER_SAMPLE_RATE);
|
||||
|
||||
const int n_segment = 0;
|
||||
const char *text = whisper_full_get_segment_text(gf->whisper_context, n_segment);
|
||||
// const char *text = whisper_full_get_segment_text(gf->whisper_context, n_segment);
|
||||
const int64_t t0 = offset_ms;
|
||||
const int64_t t1 = offset_ms + duration_ms;
|
||||
|
||||
float sentence_p = 0.0f;
|
||||
const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment);
|
||||
std::string text = "";
|
||||
std::string tokenIds = "";
|
||||
std::vector<whisper_token_data> tokens;
|
||||
bool end = false;
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
sentence_p += whisper_full_get_token_p(gf->whisper_context, n_segment, j);
|
||||
// get token
|
||||
whisper_token_data token =
|
||||
whisper_full_get_token_data(gf->whisper_context, n_segment, j);
|
||||
const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
|
||||
bool keep = !end;
|
||||
// if the token starts with '[' and ends with ']', don't keep it
|
||||
if (token_str[0] == '[' && token_str[strlen(token_str) - 1] == ']') {
|
||||
keep = false;
|
||||
}
|
||||
if ((j == n_tokens - 2 || j == n_tokens - 3) && token.p < 0.5) {
|
||||
keep = false;
|
||||
}
|
||||
// if the second to last token is .id == 13 ('.'), don't keep it
|
||||
if (j == n_tokens - 2 && token.id == 13) {
|
||||
keep = false;
|
||||
}
|
||||
// token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json
|
||||
if (token.id > 50540 && token.id <= 51865) {
|
||||
obs_log(gf->log_level,
|
||||
"Large time token found (%d), this shouldn't happen",
|
||||
token.id);
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
|
||||
}
|
||||
|
||||
if (keep) {
|
||||
text += token_str;
|
||||
tokenIds += std::to_string(token.id) + " (" +
|
||||
std::string(token_str) + "), ";
|
||||
tokens.push_back(token);
|
||||
}
|
||||
obs_log(gf->log_level, "Token %d: %d, %s, p: %.3f, dtw: %ld [keep: %d]", j,
|
||||
token.id, token_str, token.p, token.t_dtw, keep);
|
||||
}
|
||||
sentence_p /= (float)n_tokens;
|
||||
obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str());
|
||||
obs_log(gf->log_level, "Token IDs: %s", tokenIds.c_str());
|
||||
|
||||
// convert text to lowercase
|
||||
std::string text_lower(text);
|
||||
std::transform(text_lower.begin(), text_lower.end(), text_lower.begin(), ::tolower);
|
||||
// trim whitespace (use lambda)
|
||||
text_lower.erase(std::find_if(text_lower.rbegin(), text_lower.rend(),
|
||||
[](unsigned char ch) { return !std::isspace(ch); })
|
||||
.base(),
|
||||
text_lower.end());
|
||||
// if suppression is enabled, check if the text is in the suppression list
|
||||
if (!gf->suppress_sentences.empty()) {
|
||||
std::string suppress_sentences_copy = gf->suppress_sentences;
|
||||
size_t pos = 0;
|
||||
std::string token;
|
||||
while ((pos = suppress_sentences_copy.find("\n")) != std::string::npos) {
|
||||
token = suppress_sentences_copy.substr(0, pos);
|
||||
suppress_sentences_copy.erase(0, pos + 1);
|
||||
if (text == suppress_sentences_copy) {
|
||||
obs_log(gf->log_level, "Suppressing sentence: %s",
|
||||
text.c_str());
|
||||
return {DETECTION_RESULT_SUPPRESSED, "", 0, 0, {}};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (gf->log_words) {
|
||||
obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
|
||||
to_timestamp(t1).c_str(), sentence_p, text_lower.c_str());
|
||||
to_timestamp(t1).c_str(), sentence_p, text.c_str());
|
||||
}
|
||||
|
||||
if (text_lower.empty() || text_lower == ".") {
|
||||
return {DETECTION_RESULT_SILENCE, "", 0, 0};
|
||||
if (text.empty() || text == ".") {
|
||||
return {DETECTION_RESULT_SILENCE, "", 0, 0, {}};
|
||||
}
|
||||
|
||||
return {DETECTION_RESULT_SPEECH, text_lower, offset_ms, offset_ms + duration_ms};
|
||||
return {DETECTION_RESULT_SPEECH, text, offset_ms, offset_ms + duration_ms, tokens};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,14 +368,14 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
|
||||
num_new_frames_from_infos -= info_from_buf.frames;
|
||||
circlebuf_push_front(&gf->info_buffer, &info_from_buf,
|
||||
size_of_audio_info);
|
||||
last_step_in_segment =
|
||||
true; // this is the final step in the segment
|
||||
// this is the final step in the segment
|
||||
last_step_in_segment = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
obs_log(gf->log_level,
|
||||
"with %lu remaining to full segment, popped %d info-frames, pushing into buffer at %lu",
|
||||
"with %lu remaining to full segment, popped %d info-frames, pushing at %lu (overlap)",
|
||||
remaining_frames_to_full_segment, num_new_frames_from_infos,
|
||||
gf->last_num_frames);
|
||||
|
||||
@@ -340,7 +401,7 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
|
||||
}
|
||||
} else {
|
||||
gf->last_num_frames = num_new_frames_from_infos;
|
||||
obs_log(gf->log_level, "first segment, %d frames to process",
|
||||
obs_log(gf->log_level, "first segment, no overlap exists, %d frames to process",
|
||||
(int)(gf->last_num_frames));
|
||||
}
|
||||
|
||||
@@ -352,51 +413,92 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// resample to 16kHz
|
||||
float *output[MAX_PREPROC_CHANNELS];
|
||||
uint32_t out_frames;
|
||||
float *resampled_16khz[MAX_PREPROC_CHANNELS];
|
||||
uint32_t resampled_16khz_frames;
|
||||
uint64_t ts_offset;
|
||||
audio_resampler_resample(gf->resampler, (uint8_t **)output, &out_frames, &ts_offset,
|
||||
audio_resampler_resample(gf->resampler_to_whisper, (uint8_t **)resampled_16khz,
|
||||
&resampled_16khz_frames, &ts_offset,
|
||||
(const uint8_t **)gf->copy_buffers, (uint32_t)gf->last_num_frames);
|
||||
|
||||
obs_log(gf->log_level, "%d channels, %d frames, %f ms", (int)gf->channels, (int)out_frames,
|
||||
(float)out_frames / WHISPER_SAMPLE_RATE * 1000.0f);
|
||||
obs_log(gf->log_level, "%d channels, %d frames, %f ms", (int)gf->channels,
|
||||
(int)resampled_16khz_frames,
|
||||
(float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f);
|
||||
|
||||
bool skipped_inference = false;
|
||||
uint32_t speech_start_frame = 0;
|
||||
uint32_t speech_end_frame = out_frames;
|
||||
uint32_t speech_end_frame = resampled_16khz_frames;
|
||||
|
||||
if (gf->vad_enabled) {
|
||||
std::vector<float> vad_input(output[0], output[0] + out_frames);
|
||||
std::vector<float> vad_input(resampled_16khz[0],
|
||||
resampled_16khz[0] + resampled_16khz_frames);
|
||||
gf->vad->process(vad_input);
|
||||
|
||||
auto stamps = gf->vad->get_speech_timestamps();
|
||||
std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
|
||||
if (stamps.size() == 0) {
|
||||
obs_log(gf->log_level, "VAD detected no speech in %d frames",
|
||||
resampled_16khz_frames);
|
||||
skipped_inference = true;
|
||||
// prevent copying the buffer to the beginning (overlap)
|
||||
gf->last_num_frames = 0;
|
||||
last_step_in_segment = false;
|
||||
} else {
|
||||
speech_start_frame = stamps[0].start;
|
||||
speech_start_frame = (stamps[0].start < 3000) ? 0 : stamps[0].start;
|
||||
speech_end_frame = stamps.back().end;
|
||||
obs_log(gf->log_level, "VAD detected speech from %d to %d",
|
||||
speech_start_frame, speech_end_frame);
|
||||
uint32_t number_of_frames = speech_end_frame - speech_start_frame;
|
||||
|
||||
obs_log(gf->log_level,
|
||||
"VAD detected speech from %d to %d (%d frames, %d ms)",
|
||||
speech_start_frame, speech_end_frame, number_of_frames,
|
||||
number_of_frames * 1000 / WHISPER_SAMPLE_RATE);
|
||||
|
||||
// if the speech segment is less than 1 second - put the audio back into the buffer
|
||||
// to be handled in the next iteration
|
||||
if (number_of_frames > 0 && number_of_frames < WHISPER_SAMPLE_RATE) {
|
||||
// convert speech_start_frame and speech_end_frame to original sample rate
|
||||
speech_start_frame =
|
||||
speech_start_frame * gf->sample_rate / WHISPER_SAMPLE_RATE;
|
||||
speech_end_frame =
|
||||
speech_end_frame * gf->sample_rate / WHISPER_SAMPLE_RATE;
|
||||
number_of_frames = speech_end_frame - speech_start_frame;
|
||||
|
||||
// use memmove to copy the speech segment to the beginning of the buffer
|
||||
for (size_t c = 0; c < gf->channels; c++) {
|
||||
memmove(gf->copy_buffers[c],
|
||||
gf->copy_buffers[c] + speech_start_frame,
|
||||
number_of_frames * sizeof(float));
|
||||
}
|
||||
|
||||
obs_log(gf->log_level,
|
||||
"Speech segment is less than 1 second, moving %d to %d (len %d) to buffer start",
|
||||
speech_start_frame, speech_end_frame, number_of_frames);
|
||||
// no processing of the segment
|
||||
skipped_inference = true;
|
||||
// reset the last_num_frames to the number of frames in the buffer
|
||||
gf->last_num_frames = number_of_frames;
|
||||
// prevent copying the buffer to the beginning (overlap)
|
||||
last_step_in_segment = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!skipped_inference) {
|
||||
// run inference
|
||||
const struct DetectionResultWithText inference_result = run_whisper_inference(
|
||||
gf, output[0] + speech_start_frame, speech_end_frame - speech_start_frame);
|
||||
gf, resampled_16khz[0] + speech_start_frame,
|
||||
speech_end_frame - speech_start_frame, speech_start_frame == 0);
|
||||
|
||||
if (inference_result.result == DETECTION_RESULT_SPEECH) {
|
||||
// output inference result to a text source
|
||||
set_text_callback(gf, inference_result);
|
||||
} else if (inference_result.result == DETECTION_RESULT_SILENCE) {
|
||||
// output inference result to a text source
|
||||
set_text_callback(gf, {inference_result.result, "[silence]", 0, 0});
|
||||
set_text_callback(gf, {inference_result.result, "[silence]", 0, 0, {}});
|
||||
}
|
||||
} else {
|
||||
if (gf->log_words) {
|
||||
obs_log(LOG_INFO, "skipping inference");
|
||||
}
|
||||
set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "[skip]", 0, 0});
|
||||
set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "[skip]", 0, 0, {}});
|
||||
}
|
||||
|
||||
// end of timer
|
||||
@@ -407,6 +509,12 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
|
||||
(int)duration);
|
||||
|
||||
if (last_step_in_segment) {
|
||||
const uint64_t overlap_size_ms =
|
||||
(uint64_t)(gf->overlap_frames * 1000 / gf->sample_rate);
|
||||
obs_log(gf->log_level,
|
||||
"copying %lu frames (%lu ms) from the end of the buffer (pos %lu) to the beginning",
|
||||
gf->overlap_frames, overlap_size_ms,
|
||||
gf->last_num_frames - gf->overlap_frames);
|
||||
for (size_t c = 0; c < gf->channels; c++) {
|
||||
// This is the last step in the segment - reset the copy buffer (include overlap frames)
|
||||
// move overlap frames from the end of the last copy_buffers to the beginning
|
||||
@@ -416,8 +524,8 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
|
||||
// zero out the rest of the buffer, just in case
|
||||
memset(gf->copy_buffers[c] + gf->overlap_frames, 0,
|
||||
(gf->frames - gf->overlap_frames) * sizeof(float));
|
||||
gf->last_num_frames = gf->overlap_frames;
|
||||
}
|
||||
gf->last_num_frames = gf->overlap_frames;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,32 @@
|
||||
#ifndef WHISPER_PROCESSING_H
|
||||
#define WHISPER_PROCESSING_H
|
||||
|
||||
#include <whisper.h>
|
||||
|
||||
// buffer size in msec
|
||||
#define DEFAULT_BUFFER_SIZE_MSEC 2000
|
||||
#define DEFAULT_BUFFER_SIZE_MSEC 3000
|
||||
// overlap in msec
|
||||
#define DEFAULT_OVERLAP_SIZE_MSEC 100
|
||||
#define MAX_OVERLAP_SIZE_MSEC 1000
|
||||
#define MIN_OVERLAP_SIZE_MSEC 100
|
||||
|
||||
// Outcome classification of a single whisper inference pass.
enum DetectionResult {
	DETECTION_RESULT_UNKNOWN = 0, // error occurred or inference was skipped
	DETECTION_RESULT_SILENCE = 1, // decoded text was empty (no speech)
	DETECTION_RESULT_SPEECH = 2,  // speech was transcribed successfully
	DETECTION_RESULT_SUPPRESSED = 3, // text matched the suppression list
};
|
||||
|
||||
// Result of one whisper inference pass over an audio segment.
struct DetectionResultWithText {
	DetectionResult result; // outcome classification (see DetectionResult)
	std::string text;       // transcribed text; empty unless result is SPEECH
	uint64_t start_timestamp_ms; // segment start, ms since the stream began
	uint64_t end_timestamp_ms;   // segment end, ms since the stream began
	// per-token data kept for overlap reconstruction / buffered output
	std::vector<whisper_token_data> tokens;
};
|
||||
|
||||
void whisper_loop(void *data);
|
||||
struct whisper_context *init_whisper_context(const std::string &model_path);
|
||||
struct whisper_context *init_whisper_context(const std::string &model_path,
|
||||
struct transcription_filter_data *gf);
|
||||
|
||||
#endif // WHISPER_PROCESSING_H
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include <obs-module.h>
|
||||
|
||||
void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s)
|
||||
void update_whsiper_model(struct transcription_filter_data *gf, obs_data_t *s)
|
||||
{
|
||||
// update the whisper model path
|
||||
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
|
||||
@@ -13,9 +13,12 @@ void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t
|
||||
|
||||
if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
|
||||
is_external_model) {
|
||||
// model path changed, reload the model
|
||||
obs_log(gf->log_level, "model path changed from %s to %s",
|
||||
gf->whisper_model_path.c_str(), new_model_path.c_str());
|
||||
|
||||
if (gf->whisper_model_path != new_model_path) {
|
||||
// model path changed
|
||||
obs_log(gf->log_level, "model path changed from %s to %s",
|
||||
gf->whisper_model_path.c_str(), new_model_path.c_str());
|
||||
}
|
||||
|
||||
// check if the new model is external file
|
||||
if (!is_external_model) {
|
||||
@@ -76,6 +79,21 @@ void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t
|
||||
obs_log(gf->log_level, "Model path did not change: %s == %s",
|
||||
gf->whisper_model_path.c_str(), new_model_path.c_str());
|
||||
}
|
||||
|
||||
const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
|
||||
|
||||
if (new_dtw_timestamps != gf->enable_token_ts_dtw) {
|
||||
// dtw_token_timestamps changed
|
||||
obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d",
|
||||
gf->enable_token_ts_dtw, new_dtw_timestamps);
|
||||
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
|
||||
shutdown_whisper_thread(gf);
|
||||
start_whisper_thread_with_path(gf, gf->whisper_model_path);
|
||||
} else {
|
||||
// dtw_token_timestamps did not change
|
||||
obs_log(gf->log_level, "dtw_token_timestamps did not change: %d == %d",
|
||||
gf->enable_token_ts_dtw, new_dtw_timestamps);
|
||||
}
|
||||
}
|
||||
|
||||
void shutdown_whisper_thread(struct transcription_filter_data *gf)
|
||||
@@ -122,9 +140,12 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
|
||||
#else
|
||||
std::string silero_vad_model_path = silero_vad_model_file;
|
||||
#endif
|
||||
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE));
|
||||
// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
|
||||
// for silero vad parameters
|
||||
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 64, 0.5f, 1000,
|
||||
200, 250));
|
||||
|
||||
gf->whisper_context = init_whisper_context(path);
|
||||
gf->whisper_context = init_whisper_context(path, gf);
|
||||
if (gf->whisper_context == nullptr) {
|
||||
obs_log(LOG_ERROR, "Failed to initialize whisper context");
|
||||
return;
|
||||
@@ -133,3 +154,100 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
|
||||
std::thread new_whisper_thread(whisper_loop, gf);
|
||||
gf->whisper_thread.swap(new_whisper_thread);
|
||||
}
|
||||
|
||||
// Finds start of 2-token overlap between two sequences of tokens
|
||||
// Returns a pair of indices of the first overlapping tokens in the two sequences
|
||||
// If no overlap is found, the function returns {-1, -1}
|
||||
// Allows for a single token mismatch in the overlap
|
||||
std::pair<int, int> findStartOfOverlap(const std::vector<whisper_token_data> &seq1,
|
||||
const std::vector<whisper_token_data> &seq2)
|
||||
{
|
||||
if (seq1.empty() || seq2.empty() || seq1.size() == 1 || seq2.size() == 1) {
|
||||
return {-1, -1};
|
||||
}
|
||||
for (size_t i = seq1.size() - 2; i >= seq1.size() / 2; --i) {
|
||||
for (size_t j = 0; j < seq2.size() - 1; ++j) {
|
||||
if (seq1[i].id == seq2[j].id) {
|
||||
// Check if the next token in both sequences is the same
|
||||
if (seq1[i + 1].id == seq2[j + 1].id) {
|
||||
return {i, j};
|
||||
}
|
||||
// 1-skip check on seq1
|
||||
if (i + 2 < seq1.size() && seq1[i + 2].id == seq2[j + 1].id) {
|
||||
return {i, j};
|
||||
}
|
||||
// 1-skip check on seq2
|
||||
if (j + 2 < seq2.size() && seq1[i + 1].id == seq2[j + 2].id) {
|
||||
return {i, j};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return {-1, -1};
|
||||
}
|
||||
|
||||
// Reconstructs a whole sentence from two (possibly overlapping) token sequences.
// The overlap is located with findStartOfOverlap(); if none is found, the result
// is a concatenation of the two sequences with a few small-overlap heuristics
// applied to avoid duplicating boundary tokens.
std::vector<whisper_token_data> reconstructSentence(const std::vector<whisper_token_data> &seq1,
						    const std::vector<whisper_token_data> &seq2)
{
	auto overlap = findStartOfOverlap(seq1, seq2);
	std::vector<whisper_token_data> reconstructed;

	if (overlap.first == -1 || overlap.second == -1) {
		// No 2-token overlap found — handle degenerate inputs first.
		if (seq1.empty() && seq2.empty()) {
			return reconstructed;
		}
		if (seq1.empty()) {
			return seq2;
		}
		if (seq2.empty()) {
			return seq1;
		}

		// Return concat of seq1 and seq2 if no overlap found.
		// The branch order matters: each case drops a different set of
		// boundary tokens, and only the first matching case applies.
		// check if the last token of seq1 == the first token of seq2
		if (seq1.back().id == seq2.front().id) {
			// don't add the last token of seq1
			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
			reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
		} else if (seq2.size() > 1ull && seq1.back().id == seq2[1].id) {
			// check if the last token of seq1 == the second token of seq2
			// don't add the last token of seq1
			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
			// don't add the first token of seq2
			reconstructed.insert(reconstructed.end(), seq2.begin() + 1, seq2.end());
		} else if (seq1.size() > 1ull && seq1[seq1.size() - 2].id == seq2.front().id) {
			// check if the second to last token of seq1 == the first token of seq2
			// don't add the last two tokens of seq1
			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 2);
			reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
		} else {
			// add all tokens of seq1
			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end());
			reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
		}
		return reconstructed;
	}

	// Add tokens from the first sequence up to the overlap
	reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.begin() + overlap.first);

	// Determine the length of the overlap by walking both sequences in
	// lockstep from the overlap start while the token ids keep matching.
	size_t overlapLength = 0;
	while (overlap.first + overlapLength < seq1.size() &&
	       overlap.second + overlapLength < seq2.size() &&
	       seq1[overlap.first + overlapLength].id == seq2[overlap.second + overlapLength].id) {
		overlapLength++;
	}

	// Add overlapping tokens (taken from seq1)
	reconstructed.insert(reconstructed.end(), seq1.begin() + overlap.first,
			     seq1.begin() + overlap.first + overlapLength);

	// Add remaining tokens from the second sequence
	reconstructed.insert(reconstructed.end(), seq2.begin() + overlap.second + overlapLength,
			     seq2.end());

	return reconstructed;
}
|
||||
|
||||
@@ -7,8 +7,13 @@
|
||||
|
||||
#include <string>

// Settings-driven model updates. NOTE(review): "whsiper" is a long-standing
// typo for "whisper" in these names — kept as-is for caller compatibility.
void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s);
void update_whsiper_model(struct transcription_filter_data *gf, obs_data_t *s);
// Stop the background whisper inference thread and release its resources.
void shutdown_whisper_thread(struct transcription_filter_data *gf);
// (Re)start the whisper inference thread using the model at `path`.
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path);

// Finds the start of a 2-token overlap between two token sequences;
// returns {-1, -1} if no overlap is found.
std::pair<int, int> findStartOfOverlap(const std::vector<whisper_token_data> &seq1,
				       const std::vector<whisper_token_data> &seq2);
// Reconstructs a single token sequence from two overlapping ones;
// falls back to concatenation when no overlap is found.
std::vector<whisper_token_data> reconstructSentence(const std::vector<whisper_token_data> &seq1,
						    const std::vector<whisper_token_data> &seq2);

#endif /* WHISPER_UTILS_H */
|
||||
|
||||
Reference in New Issue
Block a user