Merge pull request #1 from royshil/roy.add_threaded_whisper_cpp

Add threaded whisper cpp
This commit is contained in:
Roy Shilkrot
2023-08-14 09:55:48 +03:00
committed by GitHub
29 changed files with 1875 additions and 22 deletions

View File

@@ -44,7 +44,7 @@ BreakBeforeBraces: Custom
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeColon
BreakStringLiterals: false # apparently unpredictable
ColumnLimit: 80
ColumnLimit: 100
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 8
@@ -53,7 +53,7 @@ Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
FixNamespaceComments: false
ForEachMacros:
ForEachMacros:
- 'json_object_foreach'
- 'json_object_foreach_safe'
- 'json_array_foreach'

View File

@@ -24,6 +24,10 @@ inputs:
description: 'Developer ID for installer package codesigning (macOS only)'
required: false
default: ''
codesignTeam:
description: 'Developer team for codesigning (macOS only)'
required: false
default: ''
codesignUser:
description: 'Apple ID username for notarization (macOS only)'
required: false
@@ -50,6 +54,7 @@ runs:
env:
CODESIGN_IDENT: ${{ inputs.codesignIdent }}
CODESIGN_IDENT_INSTALLER: ${{ inputs.installerIdent }}
CODESIGN_TEAM: ${{ inputs.codesignTeam }}
CODESIGN_IDENT_USER: ${{ inputs.codesignUser }}
CODESIGN_IDENT_PASS: ${{ inputs.codesignPass }}
run: |

View File

@@ -129,6 +129,7 @@ jobs:
codesign: ${{ fromJSON(needs.check-event.outputs.codesign) && fromJSON(steps.codesign.outputs.haveCodesignIdent) }}
codesignIdent: ${{ steps.codesign.outputs.codesignIdent }}
installerIdent: ${{ steps.codesign.outputs.installerIdent }}
codesignTeam: ${{ steps.codesign.outputs.codesignTeam }}
notarize: ${{ fromJSON(needs.check-event.outputs.notarize) && fromJSON(steps.codesign.outputs.haveNotarizationUser) }}
codesignUser: ${{ secrets.MACOS_NOTARIZATION_USERNAME }}
codesignPass: ${{ secrets.MACOS_NOTARIZATION_PASSWORD }}

2
.gitignore vendored
View File

@@ -15,6 +15,8 @@
!CMakePresets.json
!LICENSE
!README.md
!/vendor
!patch_libobs.diff
# Exclude lock files
*.lock.json

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "vendor/curl"]
path = vendor/curl
url = https://github.com/curl/curl.git

View File

@@ -4,8 +4,8 @@ include("${CMAKE_CURRENT_SOURCE_DIR}/cmake/common/bootstrap.cmake" NO_POLICY_SCO
project(${_name} VERSION ${_version})
option(ENABLE_FRONTEND_API "Use obs-frontend-api for UI functionality" OFF)
option(ENABLE_QT "Use Qt functionality" OFF)
option(ENABLE_FRONTEND_API "Use obs-frontend-api for UI functionality" ON)
option(ENABLE_QT "Use Qt functionality" ON)
include(compilerconfig)
include(defaults)
@@ -34,6 +34,25 @@ if(ENABLE_QT)
AUTORCC ON)
endif()
target_sources(${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c)
set(USE_SYSTEM_CURL
OFF
CACHE STRING "Use system cURL")
if(USE_SYSTEM_CURL)
find_package(CURL REQUIRED)
target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE "${CURL_LIBRARIES}")
target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${CURL_INCLUDE_DIRS}")
else()
include(cmake/BuildMyCurl.cmake)
target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE libcurl)
endif()
include(cmake/BuildWhispercpp.cmake)
target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Whispercpp)
target_sources(
${CMAKE_PROJECT_NAME}
PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c src/whisper-processing.cpp
src/model-utils/model-downloader.cpp src/model-utils/model-downloader-ui.cpp)
set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})

View File

@@ -23,8 +23,8 @@
"CMAKE_OSX_DEPLOYMENT_TARGET": "11.0",
"CODESIGN_IDENTITY": "$penv{CODESIGN_IDENT}",
"CODESIGN_TEAM": "$penv{CODESIGN_TEAM}",
"ENABLE_FRONTEND_API": false,
"ENABLE_QT": false
"ENABLE_FRONTEND_API": true,
"ENABLE_QT": true
}
},
{
@@ -53,8 +53,8 @@
"cacheVariables": {
"QT_VERSION": "6",
"CMAKE_SYSTEM_VERSION": "10.0.18363.657",
"ENABLE_FRONTEND_API": false,
"ENABLE_QT": false
"ENABLE_FRONTEND_API": true,
"ENABLE_QT": true
}
},
{
@@ -81,8 +81,8 @@
"cacheVariables": {
"QT_VERSION": "6",
"CMAKE_BUILD_TYPE": "RelWithDebInfo",
"ENABLE_FRONTEND_API": false,
"ENABLE_QT": false
"ENABLE_FRONTEND_API": true,
"ENABLE_QT": true
}
},
{
@@ -110,8 +110,8 @@
"cacheVariables": {
"QT_VERSION": "6",
"CMAKE_BUILD_TYPE": "RelWithDebInfo",
"ENABLE_FRONTEND_API": false,
"ENABLE_QT": false
"ENABLE_FRONTEND_API": true,
"ENABLE_QT": true
}
},
{

View File

@@ -41,14 +41,14 @@
},
"platformConfig": {
"macos": {
"bundleId": "com.example.obs-plugintemplate"
"bundleId": "com.royshilkrot.obs-localvocal"
}
},
"name": "obs-plugintemplate",
"version": "1.0.0",
"author": "Your Name Here",
"website": "https://example.com",
"email": "me@example.com",
"name": "obs-localvocal",
"version": "0.0.1",
"author": "Roy Shilkrot",
"website": "https://github.com/royshil/obs-localvocal",
"email": "roy.shil@gmail.com",
"uuids": {
"macosPackage": "00000000-0000-0000-0000-000000000000",
"macosInstaller": "00000000-0000-0000-0000-000000000000",

29
cmake/BuildMyCurl.cmake Normal file
View File

@@ -0,0 +1,29 @@
set(LIBCURL_SOURCE_DIR ${CMAKE_SOURCE_DIR}/vendor/curl)
find_package(Git QUIET)
execute_process(
COMMAND ${GIT_EXECUTABLE} checkout curl-8_2_0
WORKING_DIRECTORY ${LIBCURL_SOURCE_DIR}
RESULT_VARIABLE GIT_SUBMOD_RESULT)
if(OS_MACOS)
set(CURL_USE_OPENSSL OFF)
set(CURL_USE_SECTRANSP ON)
elseif(OS_WINDOWS)
set(CURL_USE_OPENSSL OFF)
set(CURL_USE_SCHANNEL ON)
elseif(OS_LINUX)
add_compile_options(-fPIC)
set(CURL_USE_OPENSSL ON)
endif()
set(BUILD_CURL_EXE OFF)
set(BUILD_SHARED_LIBS OFF)
set(HTTP_ONLY OFF)
set(CURL_USE_LIBSSH2 OFF)
add_subdirectory(${LIBCURL_SOURCE_DIR} EXCLUDE_FROM_ALL)
if(OS_MACOS)
target_compile_options(
libcurl PRIVATE -Wno-error=ambiguous-macro -Wno-error=deprecated-declarations -Wno-error=unreachable-code
-Wno-error=unused-parameter -Wno-error=unused-variable)
endif()
include_directories(SYSTEM ${LIBCURL_SOURCE_DIR}/include)

View File

@@ -0,0 +1,51 @@
include(ExternalProject)
set(CMAKE_OSX_ARCHITECTURES_ "arm64$<SEMICOLON>x86_64")
if(${CMAKE_BUILD_TYPE} STREQUAL Release OR ${CMAKE_BUILD_TYPE} STREQUAL RelWithDebInfo)
set(Whispercpp_BUILD_TYPE Release)
else()
set(Whispercpp_BUILD_TYPE Debug)
endif()
# On linux add the `-fPIC` flag to the compiler
if(UNIX AND NOT APPLE)
set(WHISPER_EXTRA_CXX_FLAGS "-fPIC")
endif()
ExternalProject_Add(
Whispercpp_Build
DOWNLOAD_EXTRACT_TIMESTAMP true
GIT_REPOSITORY https://github.com/ggerganov/whisper.cpp.git
GIT_TAG 7b374c9ac9b9861bb737eec060e4dfa29d229259
BUILD_COMMAND ${CMAKE_COMMAND} --build <BINARY_DIR> --config ${Whispercpp_BUILD_TYPE}
BUILD_BYPRODUCTS <INSTALL_DIR>/lib/static/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX}
CMAKE_GENERATOR ${CMAKE_GENERATOR}
INSTALL_COMMAND ${CMAKE_COMMAND} --install <BINARY_DIR> --config ${Whispercpp_BUILD_TYPE}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
-DCMAKE_BUILD_TYPE=${Whispercpp_BUILD_TYPE}
-DCMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM}
-DCMAKE_OSX_DEPLOYMENT_TARGET=10.13
-DCMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES_}
-DCMAKE_CXX_FLAGS=${WHISPER_EXTRA_CXX_FLAGS}
-DCMAKE_C_FLAGS=${WHISPER_EXTRA_CXX_FLAGS}
-DBUILD_SHARED_LIBS=OFF
-DWHISPER_BUILD_TESTS=OFF
-DWHISPER_BUILD_EXAMPLES=OFF
-DWHISPER_OPENBLAS=ON)
ExternalProject_Get_Property(Whispercpp_Build INSTALL_DIR)
add_library(Whispercpp::Whisper STATIC IMPORTED)
set_target_properties(
Whispercpp::Whisper
PROPERTIES IMPORTED_LOCATION
${INSTALL_DIR}/lib/static/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX})
add_library(Whispercpp INTERFACE)
add_dependencies(Whispercpp Whispercpp_Build)
target_link_libraries(Whispercpp INTERFACE Whispercpp::Whisper)
set_target_properties(Whispercpp::Whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${INSTALL_DIR}/include)
if(APPLE)
target_link_libraries(Whispercpp INTERFACE "-framework Accelerate")
endif(APPLE)

View File

@@ -73,6 +73,13 @@ function(_setup_obs_studio)
set(_cmake_version "3.0.0")
endif()
message(STATUS "Patch libobs")
execute_process(
COMMAND patch --forward "libobs/CMakeLists.txt" "${CMAKE_CURRENT_SOURCE_DIR}/patch_libobs.diff"
RESULT_VARIABLE _process_result
WORKING_DIRECTORY "${dependencies_dir}/${_obs_destination}")
message(STATUS "Patch - done")
message(STATUS "Configure ${label} (${arch})")
execute_process(
COMMAND

View File

@@ -0,0 +1 @@
transcription_filterAudioFilter=LocalVocal Transcription

Binary file not shown.

20
patch_libobs.diff Normal file
View File

@@ -0,0 +1,20 @@
diff --git a/libobs/CMakeLists.txt b/libobs/CMakeLists.txt
index d2e2671..5a9242a 100644
--- a/libobs/CMakeLists.txt
+++ b/libobs/CMakeLists.txt
@@ -263,6 +263,7 @@ set(public_headers
graphics/vec3.h
graphics/vec4.h
media-io/audio-io.h
+ media-io/audio-resampler.h
media-io/frame-rate.h
media-io/media-io-defs.h
media-io/video-io.h
@@ -287,6 +288,7 @@ set(public_headers
util/base.h
util/bmem.h
util/c99defs.h
+ util/circlebuf.h
util/darray.h
util/profiler.h
util/sse-intrin.h

View File

@@ -0,0 +1,180 @@
#include "model-downloader-ui.h"
#include "plugin-support.h"
#include <obs-module.h>
const std::string MODEL_BASE_PATH = "https://huggingface.co/ggerganov/whisper.cpp";
const std::string MODEL_PREFIX = "resolve/main/";
size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream)
{
size_t written = fwrite(ptr, size, nmemb, stream);
return written;
}
ModelDownloader::ModelDownloader(
const std::string &model_name,
std::function<void(int download_status)> download_finished_callback_, QWidget *parent)
: QDialog(parent), download_finished_callback(download_finished_callback_)
{
this->setWindowTitle("Downloading model...");
this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint);
this->setFixedSize(300, 100);
this->layout = new QVBoxLayout(this);
// Add a label for the model name
QLabel *model_name_label = new QLabel(this);
model_name_label->setText(QString::fromStdString(model_name));
model_name_label->setAlignment(Qt::AlignCenter);
this->layout->addWidget(model_name_label);
this->progress_bar = new QProgressBar(this);
this->progress_bar->setRange(0, 100);
this->progress_bar->setValue(0);
this->progress_bar->setAlignment(Qt::AlignCenter);
// Show progress as a percentage
this->progress_bar->setFormat("%p%");
this->layout->addWidget(this->progress_bar);
this->download_thread = new QThread();
this->download_worker = new ModelDownloadWorker(model_name);
this->download_worker->moveToThread(this->download_thread);
connect(this->download_thread, &QThread::started, this->download_worker,
&ModelDownloadWorker::download_model);
connect(this->download_worker, &ModelDownloadWorker::download_progress, this,
&ModelDownloader::update_progress);
connect(this->download_worker, &ModelDownloadWorker::download_finished, this,
&ModelDownloader::download_finished);
connect(this->download_worker, &ModelDownloadWorker::download_finished,
this->download_thread, &QThread::quit);
connect(this->download_worker, &ModelDownloadWorker::download_finished,
this->download_worker, &ModelDownloadWorker::deleteLater);
connect(this->download_worker, &ModelDownloadWorker::download_error, this,
&ModelDownloader::show_error);
connect(this->download_thread, &QThread::finished, this->download_thread,
&QThread::deleteLater);
this->download_thread->start();
}
void ModelDownloader::update_progress(int progress)
{
this->progress_bar->setValue(progress);
}
void ModelDownloader::download_finished()
{
this->setWindowTitle("Download finished!");
this->progress_bar->setValue(100);
this->progress_bar->setFormat("Download finished!");
this->progress_bar->setAlignment(Qt::AlignCenter);
this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #05B8CC; }");
// Add a button to close the dialog
QPushButton *close_button = new QPushButton("Close", this);
this->layout->addWidget(close_button);
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
// Call the callback
this->download_finished_callback(0);
}
void ModelDownloader::show_error(const std::string &reason)
{
this->setWindowTitle("Download failed!");
this->progress_bar->setFormat("Download failed!");
this->progress_bar->setAlignment(Qt::AlignCenter);
this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #FF0000; }");
// Add a label to show the error
QLabel *error_label = new QLabel(this);
error_label->setText(QString::fromStdString(reason));
error_label->setAlignment(Qt::AlignCenter);
// Color red
error_label->setStyleSheet("QLabel { color : red; }");
this->layout->addWidget(error_label);
// Add a button to close the dialog
QPushButton *close_button = new QPushButton("Close", this);
this->layout->addWidget(close_button);
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
this->download_finished_callback(1);
}
ModelDownloadWorker::ModelDownloadWorker(const std::string &model_name_)
{
this->model_name = model_name_;
}
void ModelDownloadWorker::download_model()
{
std::string module_data_dir = obs_get_module_data_path(obs_current_module());
// join the directory and the filename using the platform-specific separator
std::string model_save_path = module_data_dir + "/" + this->model_name;
obs_log(LOG_INFO, "Model save path: %s", model_save_path.c_str());
// extract filename from path in this->modle_name
const std::string model_filename =
this->model_name.substr(this->model_name.find_last_of("/\\") + 1);
std::string model_url = MODEL_BASE_PATH + "/" + MODEL_PREFIX + model_filename;
obs_log(LOG_INFO, "Model URL: %s", model_url.c_str());
CURL *curl = curl_easy_init();
if (curl) {
FILE *fp = fopen(model_save_path.c_str(), "wb");
if (fp == nullptr) {
obs_log(LOG_ERROR, "Failed to open file %s.", model_save_path.c_str());
emit download_error("Failed to open file.");
return;
}
curl_easy_setopt(curl, CURLOPT_URL, model_url.c_str());
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp);
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION,
ModelDownloadWorker::progress_callback);
curl_easy_setopt(curl, CURLOPT_XFERINFODATA, this);
// Follow redirects
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
CURLcode res = curl_easy_perform(curl);
if (res != CURLE_OK) {
obs_log(LOG_ERROR, "Failed to download model %s.",
this->model_name.c_str());
emit download_error("Failed to download model.");
}
curl_easy_cleanup(curl);
fclose(fp);
} else {
obs_log(LOG_ERROR, "Failed to initialize curl.");
emit download_error("Failed to initialize curl.");
}
emit download_finished();
}
int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
curl_off_t, curl_off_t)
{
if (dltotal == 0) {
return 0; // Unknown progress
}
ModelDownloadWorker *worker = (ModelDownloadWorker *)clientp;
if (worker == nullptr) {
obs_log(LOG_ERROR, "Worker is null.");
return 1;
}
int progress = (int)(dlnow * 100l / dltotal);
emit worker->download_progress(progress);
return 0;
}
ModelDownloader::~ModelDownloader()
{
this->download_thread->quit();
this->download_thread->wait();
delete this->download_thread;
delete this->download_worker;
}
ModelDownloadWorker::~ModelDownloadWorker()
{
// Do nothing
}

View File

@@ -0,0 +1,54 @@
#ifndef MODEL_DOWNLOADER_UI_H
#define MODEL_DOWNLOADER_UI_H
#include <QtWidgets>
#include <QThread>
#include <string>
#include <functional>
#include <curl/curl.h>
class ModelDownloadWorker : public QObject {
Q_OBJECT
public:
ModelDownloadWorker(const std::string &model_name);
~ModelDownloadWorker();
public slots:
void download_model();
signals:
void download_progress(int progress);
void download_finished();
void download_error(const std::string &reason);
private:
static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
curl_off_t ultotal, curl_off_t ulnow);
std::string model_name;
};
class ModelDownloader : public QDialog {
Q_OBJECT
public:
ModelDownloader(const std::string &model_name,
std::function<void(int download_status)> download_finished_callback,
QWidget *parent = nullptr);
~ModelDownloader();
public slots:
void update_progress(int progress);
void download_finished();
void show_error(const std::string &reason);
private:
QVBoxLayout *layout;
QProgressBar *progress_bar;
QThread *download_thread;
ModelDownloadWorker *download_worker;
// Callback for when the download is finished
std::function<void(int download_status)> download_finished_callback;
};
#endif // MODEL_DOWNLOADER_UI_H

View File

@@ -0,0 +1,42 @@
#include "model-downloader.h"
#include "plugin-support.h"
#include "model-downloader-ui.h"
#include <obs-module.h>
#include <obs-frontend-api.h>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <curl/curl.h>
bool check_if_model_exists(const std::string &model_name)
{
obs_log(LOG_INFO, "Checking if model %s exists...", model_name.c_str());
char *model_file_path = obs_module_file(model_name.c_str());
obs_log(LOG_INFO, "Model file path: %s", model_file_path);
if (model_file_path == nullptr) {
obs_log(LOG_INFO, "Model %s does not exist.", model_name.c_str());
return false;
}
if (!std::filesystem::exists(model_file_path)) {
obs_log(LOG_INFO, "Model %s does not exist.", model_file_path);
bfree(model_file_path);
return false;
}
bfree(model_file_path);
return true;
}
void download_model_with_ui_dialog(
const std::string &model_name,
std::function<void(int download_status)> download_finished_callback)
{
// Start the model downloader UI
ModelDownloader *model_downloader = new ModelDownloader(
model_name, download_finished_callback, (QWidget *)obs_frontend_get_main_window());
model_downloader->show();
}

View File

@@ -0,0 +1,14 @@
#ifndef MODEL_DOWNLOADER_H
#define MODEL_DOWNLOADER_H
#include <string>
#include <functional>
bool check_if_model_exists(const std::string &model_name);
// Start the model downloader UI dialog with a callback for when the download is finished
void download_model_with_ui_dialog(
const std::string &model_name,
std::function<void(int download_status)> download_finished_callback);
#endif // MODEL_DOWNLOADER_H

View File

@@ -22,10 +22,17 @@ with this program. If not, see <https://www.gnu.org/licenses/>
OBS_DECLARE_MODULE()
OBS_MODULE_USE_DEFAULT_LOCALE(PLUGIN_NAME, "en-US")
MODULE_EXPORT const char *obs_module_description(void)
{
return obs_module_text("LocalVocalPlugin");
}
extern struct obs_source_info transcription_filter_info;
bool obs_module_load(void)
{
obs_log(LOG_INFO, "plugin loaded successfully (version %s)",
PLUGIN_VERSION);
obs_register_source(&transcription_filter_info);
blog(LOG_INFO, "plugin loaded successfully (version %s)", PLUGIN_VERSION);
return true;
}

View File

@@ -21,6 +21,8 @@ with this program. If not, see <https://www.gnu.org/licenses/>
const char *PLUGIN_NAME = "@CMAKE_PROJECT_NAME@";
const char *PLUGIN_VERSION = "@CMAKE_PROJECT_VERSION@";
extern void blogva(int log_level, const char *format, va_list args);
void obs_log(int log_level, const char *format, ...)
{
size_t length = 4 + strlen(PLUGIN_NAME) + strlen(format);

View File

@@ -31,7 +31,6 @@ extern const char *PLUGIN_NAME;
extern const char *PLUGIN_VERSION;
void obs_log(int log_level, const char *format, ...);
extern void blogva(int log_level, const char *format, va_list args);
#ifdef __cplusplus
}

View File

@@ -0,0 +1,76 @@
#ifndef TRANSCRIPTION_FILTER_DATA_H
#define TRANSCRIPTION_FILTER_DATA_H
#include <obs.h>
#include <util/circlebuf.h>
#include <util/darray.h>
#include <media-io/audio-resampler.h>
#include <whisper.h>
#include <thread>
#include <memory>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <string>
#define MAX_PREPROC_CHANNELS 2
#define MT_ obs_module_text
struct transcription_filter_data {
obs_source_t *context; // obs input source
size_t channels; // number of channels
uint32_t sample_rate; // input sample rate
// How many input frames (in input sample rate) are needed for the next whisper frame
size_t frames;
// How many ms/frames are needed to overlap with the next whisper frame
size_t overlap_frames;
size_t overlap_ms;
// How many frames were processed in the last whisper frame (this is dynamic)
size_t last_num_frames;
/* PCM buffers */
float *copy_buffers[MAX_PREPROC_CHANNELS];
struct circlebuf info_buffer;
struct circlebuf input_buffers[MAX_PREPROC_CHANNELS];
/* Resampler */
audio_resampler_t *resampler;
/* whisper */
std::string whisper_model_path = "models/ggml-tiny.en.bin";
struct whisper_context *whisper_context;
whisper_full_params whisper_params;
float filler_p_threshold;
bool do_silence;
bool vad_enabled;
int log_level;
bool log_words;
bool active;
// Text source to output the subtitles
obs_weak_source_t *text_source;
char *text_source_name;
std::unique_ptr<std::mutex> text_source_mutex;
// Callback to set the text in the output text source (subtitles)
std::function<void(const std::string &str)> setTextCallback;
// Use std for thread and mutex
std::thread whisper_thread;
std::unique_ptr<std::mutex> whisper_buf_mutex;
std::unique_ptr<std::mutex> whisper_ctx_mutex;
std::unique_ptr<std::condition_variable> wshiper_thread_cv;
};
// Audio packet info
struct transcription_filter_audio_info {
uint32_t frames;
uint64_t timestamp;
};
#endif /* TRANSCRIPTION_FILTER_DATA_H */

View File

@@ -0,0 +1,16 @@
#include "transcription-filter.h"
struct obs_source_info transcription_filter_info = {
.id = "transcription_filter_audio_filter",
.type = OBS_SOURCE_TYPE_FILTER,
.output_flags = OBS_SOURCE_AUDIO,
.get_name = transcription_filter_name,
.create = transcription_filter_create,
.destroy = transcription_filter_destroy,
.get_defaults = transcription_filter_defaults,
.get_properties = transcription_filter_properties,
.update = transcription_filter_update,
.activate = transcription_filter_activate,
.deactivate = transcription_filter_deactivate,
.filter_audio = transcription_filter_filter_audio,
};

View File

@@ -0,0 +1,536 @@
#include <obs-module.h>
#include "plugin-support.h"
#include "transcription-filter.h"
#include "transcription-filter-data.h"
#include "whisper-processing.h"
#include "whisper-language.h"
#include "model-utils/model-downloader.h"
#include <algorithm>
inline enum speaker_layout convert_speaker_layout(uint8_t channels)
{
switch (channels) {
case 0:
return SPEAKERS_UNKNOWN;
case 1:
return SPEAKERS_MONO;
case 2:
return SPEAKERS_STEREO;
case 3:
return SPEAKERS_2POINT1;
case 4:
return SPEAKERS_4POINT0;
case 5:
return SPEAKERS_4POINT1;
case 6:
return SPEAKERS_5POINT1;
case 8:
return SPEAKERS_7POINT1;
default:
return SPEAKERS_UNKNOWN;
}
}
bool add_sources_to_list(void *list_property, obs_source_t *source)
{
auto source_id = obs_source_get_id(source);
if (strcmp(source_id, "text_ft2_source_v2") != 0 &&
strcmp(source_id, "text_gdiplus_v2") != 0) {
return true;
}
obs_property_t *sources = (obs_property_t *)list_property;
const char *name = obs_source_get_name(source);
obs_property_list_add_string(sources, name, name);
return true;
}
struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_audio_data *audio)
{
if (!audio) {
return nullptr;
}
if (data == nullptr) {
return audio;
}
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
if (!gf->active) {
return audio;
}
if (gf->whisper_context == nullptr) {
// Whisper not initialized, just pass through
return audio;
}
{
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex); // scoped lock
obs_log(gf->log_level,
"pushing %lu frames to input buffer. current size: %lu (bytes)",
(size_t)(audio->frames), gf->input_buffers[0].size);
// push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) {
circlebuf_push_back(&gf->input_buffers[c], audio->data[c],
audio->frames * sizeof(float));
}
// push audio packet info (timestamp/frame count) to info circlebuf
struct transcription_filter_audio_info info = {0};
info.frames = audio->frames; // number of frames in this packet
info.timestamp = audio->timestamp; // timestamp of this packet
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
}
return audio;
}
const char *transcription_filter_name(void *unused)
{
UNUSED_PARAMETER(unused);
return MT_("transcription_filterAudioFilter");
}
void transcription_filter_destroy(void *data)
{
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
obs_log(LOG_INFO, "transcription_filter_destroy");
{
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
if (gf->whisper_context != nullptr) {
whisper_free(gf->whisper_context);
gf->whisper_context = nullptr;
gf->wshiper_thread_cv->notify_all();
}
}
// join the thread
if (gf->whisper_thread.joinable()) {
gf->whisper_thread.join();
}
if (gf->text_source_name) {
bfree(gf->text_source_name);
gf->text_source_name = nullptr;
}
if (gf->text_source) {
obs_weak_source_release(gf->text_source);
gf->text_source = nullptr;
}
if (gf->resampler) {
audio_resampler_destroy(gf->resampler);
}
{
std::lock_guard<std::mutex> lockbuf(*gf->whisper_buf_mutex);
bfree(gf->copy_buffers[0]);
gf->copy_buffers[0] = nullptr;
for (size_t i = 0; i < gf->channels; i++) {
circlebuf_free(&gf->input_buffers[i]);
}
}
circlebuf_free(&gf->info_buffer);
bfree(gf);
}
void acquire_weak_text_source_ref(struct transcription_filter_data *gf)
{
if (!gf->text_source_name) {
obs_log(LOG_ERROR, "text_source_name is null");
return;
}
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
// acquire a weak ref to the new text source
obs_source_t *source = obs_get_source_by_name(gf->text_source_name);
if (source) {
gf->text_source = obs_source_get_weak_source(source);
obs_source_release(source);
if (!gf->text_source) {
obs_log(LOG_ERROR, "failed to get weak source for text source %s",
gf->text_source_name);
}
} else {
obs_log(LOG_ERROR, "text source '%s' not found", gf->text_source_name);
}
}
void transcription_filter_update(void *data, obs_data_t *s)
{
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
gf->log_level = (int)obs_data_get_int(s, "log_level");
gf->vad_enabled = obs_data_get_bool(s, "vad_enabled");
gf->log_words = obs_data_get_bool(s, "log_words");
// update the text source
const char *text_source_name = obs_data_get_string(s, "subtitle_sources");
obs_weak_source_t *old_weak_text_source = NULL;
if (strcmp(text_source_name, "none") == 0 || strcmp(text_source_name, "(null)") == 0) {
// new selected text source is not valid, release the old one
if (gf->text_source) {
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
old_weak_text_source = gf->text_source;
gf->text_source = nullptr;
}
if (gf->text_source_name) {
bfree(gf->text_source_name);
gf->text_source_name = nullptr;
}
} else {
// new selected text source is valid, check if it's different from the old one
if (gf->text_source_name == nullptr ||
strcmp(text_source_name, gf->text_source_name) != 0) {
// new text source is different from the old one, release the old one
if (gf->text_source) {
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
old_weak_text_source = gf->text_source;
gf->text_source = nullptr;
}
gf->text_source_name = bstrdup(text_source_name);
}
}
if (old_weak_text_source) {
obs_weak_source_release(old_weak_text_source);
}
const char *new_model_path = obs_data_get_string(s, "whisper_model_path");
if (strcmp(new_model_path, gf->whisper_model_path.c_str()) != 0) {
// model path changed, reload the model
obs_log(LOG_INFO, "model path changed, reloading model");
if (gf->whisper_context != nullptr) {
// acquire the mutex before freeing the context
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
whisper_free(gf->whisper_context);
gf->whisper_context = nullptr;
gf->wshiper_thread_cv->notify_all();
}
if (gf->whisper_thread.joinable()) {
gf->whisper_thread.join();
}
gf->whisper_model_path = bstrdup(new_model_path);
// check if the model exists, if not, download it
if (!check_if_model_exists(gf->whisper_model_path)) {
obs_log(LOG_ERROR, "Whisper model does not exist");
download_model_with_ui_dialog(
gf->whisper_model_path, [gf](int download_status) {
if (download_status == 0) {
obs_log(LOG_INFO, "Model download complete");
gf->whisper_context = init_whisper_context(
gf->whisper_model_path);
gf->whisper_thread = std::thread(whisper_loop, gf);
} else {
obs_log(LOG_ERROR, "Model download failed");
}
});
} else {
// Model exists, just load it
gf->whisper_context = init_whisper_context(gf->whisper_model_path);
gf->whisper_thread = std::thread(whisper_loop, gf);
}
}
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
gf->whisper_params = whisper_full_default_params(
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
gf->whisper_params.duration_ms = BUFFER_SIZE_MSEC;
gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select");
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
gf->whisper_params.translate = obs_data_get_bool(s, "translate");
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress");
gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime");
gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps");
gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps");
gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt");
gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum");
gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len");
gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word");
gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens");
gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up");
gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank");
gf->whisper_params.suppress_non_speech_tokens =
obs_data_get_bool(s, "suppress_non_speech_tokens");
gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature");
gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts");
gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty");
}
void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
{
struct transcription_filter_data *gf = static_cast<struct transcription_filter_data *>(
bmalloc(sizeof(struct transcription_filter_data)));
// Get the number of channels for the input source
gf->channels = audio_output_get_channels(obs_get_audio());
gf->sample_rate = audio_output_get_sample_rate(obs_get_audio());
gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)BUFFER_SIZE_MSEC));
gf->last_num_frames = 0;
for (size_t i = 0; i < MAX_AUDIO_CHANNELS; i++) {
circlebuf_init(&gf->input_buffers[i]);
}
circlebuf_init(&gf->info_buffer);
// allocate copy buffers
gf->copy_buffers[0] =
static_cast<float *>(bzalloc(gf->channels * gf->frames * sizeof(float)));
for (size_t c = 1; c < gf->channels; c++) { // set the channel pointers
gf->copy_buffers[c] = gf->copy_buffers[0] + c * gf->frames;
}
gf->context = filter;
gf->whisper_model_path = obs_data_get_string(settings, "whisper_model_path");
gf->whisper_context = init_whisper_context(gf->whisper_model_path);
if (gf->whisper_context == nullptr) {
obs_log(LOG_ERROR, "Failed to load whisper model");
return nullptr;
}
gf->overlap_ms = OVERLAP_SIZE_MSEC;
gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
obs_log(LOG_INFO, "transcription_filter filter: channels %d, frames %d, sample_rate %d",
(int)gf->channels, (int)gf->frames, gf->sample_rate);
struct resample_info src, dst;
src.samples_per_sec = gf->sample_rate;
src.format = AUDIO_FORMAT_FLOAT_PLANAR;
src.speakers = convert_speaker_layout((uint8_t)gf->channels);
dst.samples_per_sec = WHISPER_SAMPLE_RATE;
dst.format = AUDIO_FORMAT_FLOAT_PLANAR;
dst.speakers = convert_speaker_layout((uint8_t)1);
gf->resampler = audio_resampler_create(&dst, &src);
gf->active = true;
gf->whisper_buf_mutex = std::unique_ptr<std::mutex>(new std::mutex());
gf->whisper_ctx_mutex = std::unique_ptr<std::mutex>(new std::mutex());
gf->wshiper_thread_cv =
std::unique_ptr<std::condition_variable>(new std::condition_variable());
gf->text_source_mutex = std::unique_ptr<std::mutex>(new std::mutex());
// set the callback to set the text in the output text source (subtitles)
gf->setTextCallback = [gf](const std::string &str) {
if (!gf->text_source) {
// attempt to acquire a weak ref to the text source if it's yet available
acquire_weak_text_source_ref(gf);
}
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
obs_weak_source_t *text_source = gf->text_source;
if (!text_source) {
obs_log(LOG_ERROR, "text_source is null");
return;
}
auto target = obs_weak_source_get_source(text_source);
if (!target) {
obs_log(LOG_ERROR, "text_source target is null");
return;
}
auto text_settings = obs_source_get_settings(target);
obs_data_set_string(text_settings, "text", str.c_str());
obs_source_update(target, text_settings);
obs_source_release(target);
};
// get the settings updated on the filter data struct
transcription_filter_update(gf, settings);
// start the thread
gf->whisper_thread = std::thread(whisper_loop, gf);
return gf;
}
void transcription_filter_activate(void *data)
{
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
obs_log(LOG_INFO, "transcription_filter filter activated");
gf->active = true;
}
void transcription_filter_deactivate(void *data)
{
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
obs_log(LOG_INFO, "transcription_filter filter deactivated");
gf->active = false;
}
void transcription_filter_defaults(obs_data_t *s)
{
obs_data_set_default_bool(s, "vad_enabled", true);
obs_data_set_default_int(s, "log_level", LOG_DEBUG);
obs_data_set_default_bool(s, "log_words", true);
obs_data_set_default_string(s, "whisper_model_path", "models/ggml-tiny.en.bin");
obs_data_set_default_string(s, "whisper_language_select", "en");
obs_data_set_default_string(s, "subtitle_sources", "none");
// Whisper parameters
obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
obs_data_set_default_string(s, "initial_prompt", "");
obs_data_set_default_int(s, "n_threads", 4);
obs_data_set_default_int(s, "n_max_text_ctx", 16384);
obs_data_set_default_bool(s, "translate", false);
obs_data_set_default_bool(s, "no_context", true);
obs_data_set_default_bool(s, "single_segment", true);
obs_data_set_default_bool(s, "print_special", false);
obs_data_set_default_bool(s, "print_progress", false);
obs_data_set_default_bool(s, "print_realtime", false);
obs_data_set_default_bool(s, "print_timestamps", false);
obs_data_set_default_bool(s, "token_timestamps", false);
obs_data_set_default_double(s, "thold_pt", 0.01);
obs_data_set_default_double(s, "thold_ptsum", 0.01);
obs_data_set_default_int(s, "max_len", 0);
obs_data_set_default_bool(s, "split_on_word", false);
obs_data_set_default_int(s, "max_tokens", 32);
obs_data_set_default_bool(s, "speed_up", false);
obs_data_set_default_bool(s, "suppress_blank", false);
obs_data_set_default_bool(s, "suppress_non_speech_tokens", true);
obs_data_set_default_double(s, "temperature", 0.5);
obs_data_set_default_double(s, "max_initial_ts", 1.0);
obs_data_set_default_double(s, "length_penalty", -1.0);
}
obs_properties_t *transcription_filter_properties(void *data)
{
obs_properties_t *ppts = obs_properties_create();
obs_properties_add_bool(ppts, "vad_enabled", "VAD Enabled");
obs_property_t *list = obs_properties_add_list(ppts, "log_level", "Log level",
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
obs_property_list_add_int(list, "DEBUG", LOG_DEBUG);
obs_property_list_add_int(list, "INFO", LOG_INFO);
obs_property_list_add_int(list, "WARNING", LOG_WARNING);
obs_properties_add_bool(ppts, "log_words", "Log output words");
obs_property_t *sources = obs_properties_add_list(ppts, "subtitle_sources",
"subtitle_sources", OBS_COMBO_TYPE_LIST,
OBS_COMBO_FORMAT_STRING);
obs_enum_sources(add_sources_to_list, sources);
// Add a list of available whisper models to download
obs_property_t *whisper_models_list =
obs_properties_add_list(ppts, "whisper_model_path", "Whisper Model",
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
obs_property_list_add_string(whisper_models_list, "Tiny (Eng) 75Mb",
"models/ggml-tiny.en.bin");
obs_property_list_add_string(whisper_models_list, "Tiny 75Mb", "models/ggml-tiny.bin");
obs_property_list_add_string(whisper_models_list, "Base (Eng) 142Mb",
"models/ggml-base.en.bin");
obs_property_list_add_string(whisper_models_list, "Base 142Mb", "models/ggml-base.bin");
obs_property_list_add_string(whisper_models_list, "Small (Eng) 466Mb",
"models/ggml-small.en.bin");
obs_property_list_add_string(whisper_models_list, "Small 466Mb", "models/ggml-small.bin");
obs_properties_t *whisper_params_group = obs_properties_create();
obs_properties_add_group(ppts, "whisper_params_group", "Whisper Parameters",
OBS_GROUP_NORMAL, whisper_params_group);
// Add language selector
obs_property_t *whisper_language_select_list =
obs_properties_add_list(whisper_params_group, "whisper_language_select", "Language",
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
// sort the languages by flipping the map
std::map<std::string, std::string> whisper_available_lang_flip;
for (auto const &pair : whisper_available_lang) {
whisper_available_lang_flip[pair.second] = pair.first;
}
// iterate over all available languages and add them to the list
for (auto const &pair : whisper_available_lang_flip) {
// Capitalize the language name
std::string language_name = pair.first;
language_name[0] = (char)toupper(language_name[0]);
obs_property_list_add_string(whisper_language_select_list, language_name.c_str(),
pair.second.c_str());
}
obs_property_t *whisper_sampling_method_list = obs_properties_add_list(
whisper_params_group, "whisper_sampling_method", "whisper_sampling_method",
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
obs_property_list_add_int(whisper_sampling_method_list, "Beam search",
WHISPER_SAMPLING_BEAM_SEARCH);
obs_property_list_add_int(whisper_sampling_method_list, "Greedy", WHISPER_SAMPLING_GREEDY);
// int n_threads;
obs_properties_add_int_slider(whisper_params_group, "n_threads", "n_threads", 1, 8, 1);
// int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
obs_properties_add_int_slider(whisper_params_group, "n_max_text_ctx", "n_max_text_ctx", 0,
16384, 100);
// int offset_ms; // start offset in ms
// int duration_ms; // audio duration to process in ms
// bool translate;
obs_properties_add_bool(whisper_params_group, "translate", "translate");
// bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
obs_properties_add_bool(whisper_params_group, "no_context", "no_context");
// bool single_segment; // force single segment output (useful for streaming)
obs_properties_add_bool(whisper_params_group, "single_segment", "single_segment");
// bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
obs_properties_add_bool(whisper_params_group, "print_special", "print_special");
// bool print_progress; // print progress information
obs_properties_add_bool(whisper_params_group, "print_progress", "print_progress");
// bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
obs_properties_add_bool(whisper_params_group, "print_realtime", "print_realtime");
// bool print_timestamps; // print timestamps for each text segment when printing realtime
obs_properties_add_bool(whisper_params_group, "print_timestamps", "print_timestamps");
// bool token_timestamps; // enable token-level timestamps
obs_properties_add_bool(whisper_params_group, "token_timestamps", "token_timestamps");
// float thold_pt; // timestamp token probability threshold (~0.01)
obs_properties_add_float_slider(whisper_params_group, "thold_pt", "thold_pt", 0.0f, 1.0f,
0.05f);
// float thold_ptsum; // timestamp token sum probability threshold (~0.01)
obs_properties_add_float_slider(whisper_params_group, "thold_ptsum", "thold_ptsum", 0.0f,
1.0f, 0.05f);
// int max_len; // max segment length in characters
obs_properties_add_int_slider(whisper_params_group, "max_len", "max_len", 0, 100, 1);
// bool split_on_word; // split on word rather than on token (when used with max_len)
obs_properties_add_bool(whisper_params_group, "split_on_word", "split_on_word");
// int max_tokens; // max tokens per segment (0 = no limit)
obs_properties_add_int_slider(whisper_params_group, "max_tokens", "max_tokens", 0, 100, 1);
// bool speed_up; // speed-up the audio by 2x using Phase Vocoder
obs_properties_add_bool(whisper_params_group, "speed_up", "speed_up");
// const char * initial_prompt;
obs_properties_add_text(whisper_params_group, "initial_prompt", "initial_prompt",
OBS_TEXT_DEFAULT);
// bool suppress_blank
obs_properties_add_bool(whisper_params_group, "suppress_blank", "suppress_blank");
// bool suppress_non_speech_tokens
obs_properties_add_bool(whisper_params_group, "suppress_non_speech_tokens",
"suppress_non_speech_tokens");
// float temperature
obs_properties_add_float_slider(whisper_params_group, "temperature", "temperature", 0.0f,
1.0f, 0.05f);
// float max_initial_ts
obs_properties_add_float_slider(whisper_params_group, "max_initial_ts", "max_initial_ts",
0.0f, 1.0f, 0.05f);
// float length_penalty
obs_properties_add_float_slider(whisper_params_group, "length_penalty", "length_penalty",
-1.0f, 1.0f, 0.1f);
UNUSED_PARAMETER(data);
return ppts;
}

View File

@@ -0,0 +1,19 @@
#include <obs-module.h>
#ifdef __cplusplus
extern "C" {
#endif
void transcription_filter_activate(void *data);
void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter);
void transcription_filter_update(void *data, obs_data_t *s);
void transcription_filter_destroy(void *data);
const char *transcription_filter_name(void *unused);
struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_audio_data *audio);
void transcription_filter_deactivate(void *data);
void transcription_filter_defaults(obs_data_t *s);
obs_properties_t *transcription_filter_properties(void *data);
#ifdef __cplusplus
}
#endif

410
src/whisper-language.h Normal file
View File

@@ -0,0 +1,410 @@
#ifndef WHISPER_LANGUAGE_H
#define WHISPER_LANGUAGE_H
#include <map>
#include <string>
static const std::map<std::string, std::string> whisper_available_lang = {
{
"auto",
"auto",
},
{
"en",
"english",
},
{
"zh",
"chinese",
},
{
"de",
"german",
},
{
"es",
"spanish",
},
{
"ru",
"russian",
},
{
"ko",
"korean",
},
{
"fr",
"french",
},
{
"ja",
"japanese",
},
{
"pt",
"portuguese",
},
{
"tr",
"turkish",
},
{
"pl",
"polish",
},
{
"ca",
"catalan",
},
{
"nl",
"dutch",
},
{
"ar",
"arabic",
},
{
"sv",
"swedish",
},
{
"it",
"italian",
},
{
"id",
"indonesian",
},
{
"hi",
"hindi",
},
{
"fi",
"finnish",
},
{
"vi",
"vietnamese",
},
{
"he",
"hebrew",
},
{
"uk",
"ukrainian",
},
{
"el",
"greek",
},
{
"ms",
"malay",
},
{
"cs",
"czech",
},
{
"ro",
"romanian",
},
{
"da",
"danish",
},
{
"hu",
"hungarian",
},
{
"ta",
"tamil",
},
{
"no",
"norwegian",
},
{
"th",
"thai",
},
{
"ur",
"urdu",
},
{
"hr",
"croatian",
},
{
"bg",
"bulgarian",
},
{
"lt",
"lithuanian",
},
{
"la",
"latin",
},
{
"mi",
"maori",
},
{
"ml",
"malayalam",
},
{
"cy",
"welsh",
},
{
"sk",
"slovak",
},
{
"te",
"telugu",
},
{
"fa",
"persian",
},
{
"lv",
"latvian",
},
{
"bn",
"bengali",
},
{
"sr",
"serbian",
},
{
"az",
"azerbaijani",
},
{
"sl",
"slovenian",
},
{
"kn",
"kannada",
},
{
"et",
"estonian",
},
{
"mk",
"macedonian",
},
{
"br",
"breton",
},
{
"eu",
"basque",
},
{
"is",
"icelandic",
},
{
"hy",
"armenian",
},
{
"ne",
"nepali",
},
{
"mn",
"mongolian",
},
{
"bs",
"bosnian",
},
{
"kk",
"kazakh",
},
{
"sq",
"albanian",
},
{
"sw",
"swahili",
},
{
"gl",
"galician",
},
{
"mr",
"marathi",
},
{
"pa",
"punjabi",
},
{
"si",
"sinhala",
},
{
"km",
"khmer",
},
{
"sn",
"shona",
},
{
"yo",
"yoruba",
},
{
"so",
"somali",
},
{
"af",
"afrikaans",
},
{
"oc",
"occitan",
},
{
"ka",
"georgian",
},
{
"be",
"belarusian",
},
{
"tg",
"tajik",
},
{
"sd",
"sindhi",
},
{
"gu",
"gujarati",
},
{
"am",
"amharic",
},
{
"yi",
"yiddish",
},
{
"lo",
"lao",
},
{
"uz",
"uzbek",
},
{
"fo",
"faroese",
},
{
"ht",
"haitian",
},
{
"ps",
"pashto",
},
{
"tk",
"turkmen",
},
{
"nn",
"nynorsk",
},
{
"mt",
"maltese",
},
{
"sa",
"sanskrit",
},
{
"lb",
"luxembourgish",
},
{
"my",
"myanmar",
},
{
"bo",
"tibetan",
},
{
"tl",
"tagalog",
},
{
"mg",
"malagasy",
},
{
"as",
"assamese",
},
{
"tt",
"tatar",
},
{
"haw",
"hawaiian",
},
{
"ln",
"lingala",
},
{
"ha",
"hausa",
},
{
"ba",
"bashkir",
},
{
"jw",
"javanese",
},
{
"su",
"sundanese",
},
};
#endif // WHISPER_LANGUAGE_H

345
src/whisper-processing.cpp Normal file
View File

@@ -0,0 +1,345 @@
#include <whisper.h>
#include <obs-module.h>
#include "plugin-support.h"
#include "transcription-filter-data.h"
#include "whisper-processing.h"
#include <algorithm>
#include <cctype>
#define VAD_THOLD 0.0001f
#define FREQ_THOLD 100.0f
// Taken from https://github.com/ggerganov/whisper.cpp/blob/master/examples/stream/stream.cpp
std::string to_timestamp(int64_t t)
{
int64_t sec = t / 100;
int64_t msec = t - sec * 100;
int64_t min = sec / 60;
sec = sec - min * 60;
char buf[32];
snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int)min, (int)sec, (int)msec);
return std::string(buf);
}
void high_pass_filter(float *pcmf32, size_t pcm32f_size, float cutoff, uint32_t sample_rate)
{
const float rc = 1.0f / (2.0f * (float)M_PI * cutoff);
const float dt = 1.0f / (float)sample_rate;
const float alpha = dt / (rc + dt);
float y = pcmf32[0];
for (size_t i = 1; i < pcm32f_size; i++) {
y = alpha * (y + pcmf32[i] - pcmf32[i - 1]);
pcmf32[i] = y;
}
}
// VAD (voice activity detection), return true if speech detected
bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float vad_thold,
float freq_thold, bool verbose)
{
const uint64_t n_samples = pcm32f_size;
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, pcm32f_size, freq_thold, sample_rate);
}
float energy_all = 0.0f;
for (uint64_t i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
}
energy_all /= (float)n_samples;
if (verbose) {
blog(LOG_INFO, "%s: energy_all: %f, vad_thold: %f, freq_thold: %f", __func__,
energy_all, vad_thold, freq_thold);
}
if (energy_all < vad_thold) {
return false;
}
return true;
}
struct whisper_context *init_whisper_context(const std::string &model_path)
{
struct whisper_context *ctx = whisper_init_from_file(obs_module_file(model_path.c_str()));
if (ctx == nullptr) {
obs_log(LOG_ERROR, "Failed to load whisper model");
return nullptr;
}
return ctx;
}
enum DetectionResult {
DETECTION_RESULT_UNKNOWN = 0,
DETECTION_RESULT_SILENCE = 1,
DETECTION_RESULT_SPEECH = 2,
};
struct DetectionResultWithText {
DetectionResult result;
std::string text;
};
struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
const float *pcm32f_data, size_t pcm32f_size)
{
obs_log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__,
int(pcm32f_size), float(pcm32f_size) / WHISPER_SAMPLE_RATE,
gf->whisper_params.n_threads);
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
if (gf->whisper_context == nullptr) {
obs_log(LOG_WARNING, "whisper context is null");
return {DETECTION_RESULT_UNKNOWN, ""};
}
// run the inference
int whisper_full_result = -1;
try {
whisper_full_result = whisper_full(gf->whisper_context, gf->whisper_params,
pcm32f_data, (int)pcm32f_size);
} catch (const std::exception &e) {
obs_log(LOG_ERROR, "Whisper exception: %s. Filter restart is required", e.what());
whisper_free(gf->whisper_context);
gf->whisper_context = nullptr;
return {DETECTION_RESULT_UNKNOWN, ""};
}
if (whisper_full_result != 0) {
obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result);
return {DETECTION_RESULT_UNKNOWN, ""};
} else {
const int n_segment = 0;
const char *text = whisper_full_get_segment_text(gf->whisper_context, n_segment);
const int64_t t0 = whisper_full_get_segment_t0(gf->whisper_context, n_segment);
const int64_t t1 = whisper_full_get_segment_t1(gf->whisper_context, n_segment);
float sentence_p = 0.0f;
const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment);
for (int j = 0; j < n_tokens; ++j) {
sentence_p += whisper_full_get_token_p(gf->whisper_context, n_segment, j);
}
sentence_p /= (float)n_tokens;
// convert text to lowercase
std::string text_lower(text);
std::transform(text_lower.begin(), text_lower.end(), text_lower.begin(), ::tolower);
// trim whitespace (use lambda)
text_lower.erase(std::find_if(text_lower.rbegin(), text_lower.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
text_lower.end());
if (gf->log_words) {
obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
to_timestamp(t1).c_str(), sentence_p, text_lower.c_str());
}
if (text_lower.empty()) {
return {DETECTION_RESULT_SILENCE, ""};
}
return {DETECTION_RESULT_SPEECH, text_lower};
}
}
void process_audio_from_buffer(struct transcription_filter_data *gf)
{
uint32_t num_new_frames_from_infos = 0;
uint64_t start_timestamp = 0;
{
// scoped lock the buffer mutex
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex);
// We need (gf->frames - gf->overlap_frames) new frames to run inference,
// except for the first segment, where we need the whole gf->frames frames
size_t how_many_frames_needed = gf->frames - gf->overlap_frames;
if (gf->last_num_frames == 0) {
how_many_frames_needed = gf->frames;
}
// pop infos from the info buffer and mark the beginning timestamp from the first
// info as the beginning timestamp of the segment
struct transcription_filter_audio_info info_from_buf = {0};
while (gf->info_buffer.size >= sizeof(struct transcription_filter_audio_info)) {
circlebuf_pop_front(&gf->info_buffer, &info_from_buf,
sizeof(struct transcription_filter_audio_info));
num_new_frames_from_infos += info_from_buf.frames;
if (start_timestamp == 0) {
start_timestamp = info_from_buf.timestamp;
}
obs_log(gf->log_level, "popped %d frames from info buffer, %lu needed",
num_new_frames_from_infos, how_many_frames_needed);
// Check if we're within the needed segment length
if (num_new_frames_from_infos > how_many_frames_needed) {
// too big, push the last info into the buffer's front where it was
num_new_frames_from_infos -= info_from_buf.frames;
circlebuf_push_front(
&gf->info_buffer, &info_from_buf,
sizeof(struct transcription_filter_audio_info));
break;
}
}
/* Pop from input circlebuf */
for (size_t c = 0; c < gf->channels; c++) {
if (gf->last_num_frames > 0) {
// move overlap frames from the end of the last copy_buffers to the beginning
memcpy(gf->copy_buffers[c],
gf->copy_buffers[c] + gf->last_num_frames -
gf->overlap_frames,
gf->overlap_frames * sizeof(float));
// copy new data to the end of copy_buffers[c]
circlebuf_pop_front(&gf->input_buffers[c],
gf->copy_buffers[c] + gf->overlap_frames,
num_new_frames_from_infos * sizeof(float));
} else {
// Very first time, just copy data to copy_buffers[c]
circlebuf_pop_front(&gf->input_buffers[c], gf->copy_buffers[c],
num_new_frames_from_infos * sizeof(float));
}
}
obs_log(gf->log_level,
"popped %u frames from input buffer. input_buffer[0] size is %lu",
num_new_frames_from_infos, gf->input_buffers[0].size);
if (gf->last_num_frames > 0) {
gf->last_num_frames = num_new_frames_from_infos + gf->overlap_frames;
} else {
gf->last_num_frames = num_new_frames_from_infos;
}
}
obs_log(gf->log_level, "processing %d frames (%d ms), start timestamp %llu ",
(int)gf->last_num_frames, (int)(gf->last_num_frames * 1000 / gf->sample_rate),
start_timestamp);
// time the audio processing
auto start = std::chrono::high_resolution_clock::now();
// resample to 16kHz
float *output[MAX_PREPROC_CHANNELS];
uint32_t out_frames;
uint64_t ts_offset;
audio_resampler_resample(gf->resampler, (uint8_t **)output, &out_frames, &ts_offset,
(const uint8_t **)gf->copy_buffers, (uint32_t)gf->last_num_frames);
obs_log(gf->log_level, "%d channels, %d frames, %f ms", (int)gf->channels, (int)out_frames,
(float)out_frames / WHISPER_SAMPLE_RATE * 1000.0f);
bool skipped_inference = false;
if (gf->vad_enabled) {
skipped_inference = !::vad_simple(output[0], out_frames, WHISPER_SAMPLE_RATE,
VAD_THOLD, FREQ_THOLD,
gf->log_level != LOG_DEBUG);
}
if (!skipped_inference) {
// run inference
const struct DetectionResultWithText inference_result =
run_whisper_inference(gf, output[0], out_frames);
if (inference_result.result == DETECTION_RESULT_SPEECH) {
// output inference result to a text source
gf->setTextCallback(inference_result.text);
} else if (inference_result.result == DETECTION_RESULT_SILENCE) {
// output inference result to a text source
gf->setTextCallback("[silence]");
}
} else {
if (gf->log_words) {
obs_log(LOG_INFO, "skipping inference");
}
gf->setTextCallback("");
}
// end of timer
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
const uint32_t new_frames_from_infos_ms =
num_new_frames_from_infos * 1000 /
gf->sample_rate; // number of frames in this packet
obs_log(gf->log_level, "audio processing of %u ms new data took %d ms",
new_frames_from_infos_ms, (int)duration);
if (duration > new_frames_from_infos_ms) {
// try to decrease overlap down to minimum of 100 ms
gf->overlap_ms = std::max((uint64_t)gf->overlap_ms - 10, (uint64_t)100);
gf->overlap_frames = gf->overlap_ms * gf->sample_rate / 1000;
obs_log(gf->log_level,
"audio processing took too long (%d ms), reducing overlap to %lu ms",
(int)duration, gf->overlap_ms);
} else if (!skipped_inference) {
if (gf->overlap_ms < OVERLAP_SIZE_MSEC) {
// try to increase overlap up to OVERLAP_SIZE_MSEC
gf->overlap_ms = std::min((uint64_t)gf->overlap_ms + 10,
(uint64_t)OVERLAP_SIZE_MSEC);
gf->overlap_frames = gf->overlap_ms * gf->sample_rate / 1000;
obs_log(gf->log_level,
"audio processing took %d ms, increasing overlap to %lu ms",
(int)duration, gf->overlap_ms);
}
}
}
void whisper_loop(void *data)
{
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
const size_t segment_size = gf->frames * sizeof(float);
obs_log(LOG_INFO, "starting whisper thread");
// Thread main loop
while (true) {
{
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
if (gf->whisper_context == nullptr) {
obs_log(LOG_WARNING, "Whisper context is null, exiting thread");
break;
}
}
// Check if we have enough data to process
while (true) {
size_t input_buf_size = 0;
{
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex);
input_buf_size = gf->input_buffers[0].size;
}
if (input_buf_size >= segment_size) {
obs_log(gf->log_level,
"found %lu bytes, %lu frames in input buffer, need >= %lu, processing",
input_buf_size, (size_t)(input_buf_size / sizeof(float)),
segment_size);
// Process the audio. This will also remove the processed data from the input buffer.
// Mutex is locked inside process_audio_from_buffer.
process_audio_from_buffer(gf);
} else {
break;
}
}
// Sleep for 10 ms using the condition variable wshiper_thread_cv
// This will wake up the thread if there is new data in the input buffer
// or if the whisper context is null
std::unique_lock<std::mutex> lock(*gf->whisper_ctx_mutex);
gf->wshiper_thread_cv->wait_for(lock, std::chrono::milliseconds(10));
}
obs_log(LOG_INFO, "exiting whisper thread");
}

14
src/whisper-processing.h Normal file
View File

@@ -0,0 +1,14 @@
#ifndef WHISPER_PROCESSING_H
#define WHISPER_PROCESSING_H
// buffer size in msec
#define BUFFER_SIZE_MSEC 3000
// at 16Khz, 3000 msec is 48000 samples
#define WHISPER_FRAME_SIZE 48000
// overlap in msec
#define OVERLAP_SIZE_MSEC 200
void whisper_loop(void *data);
struct whisper_context *init_whisper_context(const std::string &model_path);
#endif // WHISPER_PROCESSING_H

1
vendor/curl vendored Submodule

Submodule vendor/curl added at 439ff2052e