mirror of
https://github.com/royshil/obs-localvocal.git
synced 2026-01-10 04:48:02 -05:00
Merge pull request #1 from royshil/roy.add_threaded_whisper_cpp
Add threaded whisper cpp
This commit is contained in:
@@ -44,7 +44,7 @@ BreakBeforeBraces: Custom
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializers: BeforeColon
|
||||
BreakStringLiterals: false # apparently unpredictable
|
||||
ColumnLimit: 80
|
||||
ColumnLimit: 100
|
||||
CompactNamespaces: false
|
||||
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
||||
ConstructorInitializerIndentWidth: 8
|
||||
@@ -53,7 +53,7 @@ Cpp11BracedListStyle: true
|
||||
DerivePointerAlignment: false
|
||||
DisableFormat: false
|
||||
FixNamespaceComments: false
|
||||
ForEachMacros:
|
||||
ForEachMacros:
|
||||
- 'json_object_foreach'
|
||||
- 'json_object_foreach_safe'
|
||||
- 'json_array_foreach'
|
||||
|
||||
5
.github/actions/package-plugin/action.yaml
vendored
5
.github/actions/package-plugin/action.yaml
vendored
@@ -24,6 +24,10 @@ inputs:
|
||||
description: 'Developer ID for installer package codesigning (macOS only)'
|
||||
required: false
|
||||
default: ''
|
||||
codesignTeam:
|
||||
description: 'Developer team for codesigning (macOS only)'
|
||||
required: false
|
||||
default: ''
|
||||
codesignUser:
|
||||
description: 'Apple ID username for notarization (macOS only)'
|
||||
required: false
|
||||
@@ -50,6 +54,7 @@ runs:
|
||||
env:
|
||||
CODESIGN_IDENT: ${{ inputs.codesignIdent }}
|
||||
CODESIGN_IDENT_INSTALLER: ${{ inputs.installerIdent }}
|
||||
CODESIGN_TEAM: ${{ inputs.codesignTeam }}
|
||||
CODESIGN_IDENT_USER: ${{ inputs.codesignUser }}
|
||||
CODESIGN_IDENT_PASS: ${{ inputs.codesignPass }}
|
||||
run: |
|
||||
|
||||
1
.github/workflows/build-project.yaml
vendored
1
.github/workflows/build-project.yaml
vendored
@@ -129,6 +129,7 @@ jobs:
|
||||
codesign: ${{ fromJSON(needs.check-event.outputs.codesign) && fromJSON(steps.codesign.outputs.haveCodesignIdent) }}
|
||||
codesignIdent: ${{ steps.codesign.outputs.codesignIdent }}
|
||||
installerIdent: ${{ steps.codesign.outputs.installerIdent }}
|
||||
codesignTeam: ${{ steps.codesign.outputs.codesignTeam }}
|
||||
notarize: ${{ fromJSON(needs.check-event.outputs.notarize) && fromJSON(steps.codesign.outputs.haveNotarizationUser) }}
|
||||
codesignUser: ${{ secrets.MACOS_NOTARIZATION_USERNAME }}
|
||||
codesignPass: ${{ secrets.MACOS_NOTARIZATION_PASSWORD }}
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -15,6 +15,8 @@
|
||||
!CMakePresets.json
|
||||
!LICENSE
|
||||
!README.md
|
||||
!/vendor
|
||||
!patch_libobs.diff
|
||||
|
||||
# Exclude lock files
|
||||
*.lock.json
|
||||
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "vendor/curl"]
|
||||
path = vendor/curl
|
||||
url = https://github.com/curl/curl.git
|
||||
@@ -4,8 +4,8 @@ include("${CMAKE_CURRENT_SOURCE_DIR}/cmake/common/bootstrap.cmake" NO_POLICY_SCO
|
||||
|
||||
project(${_name} VERSION ${_version})
|
||||
|
||||
option(ENABLE_FRONTEND_API "Use obs-frontend-api for UI functionality" OFF)
|
||||
option(ENABLE_QT "Use Qt functionality" OFF)
|
||||
option(ENABLE_FRONTEND_API "Use obs-frontend-api for UI functionality" ON)
|
||||
option(ENABLE_QT "Use Qt functionality" ON)
|
||||
|
||||
include(compilerconfig)
|
||||
include(defaults)
|
||||
@@ -34,6 +34,25 @@ if(ENABLE_QT)
|
||||
AUTORCC ON)
|
||||
endif()
|
||||
|
||||
target_sources(${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c)
|
||||
set(USE_SYSTEM_CURL
|
||||
OFF
|
||||
CACHE STRING "Use system cURL")
|
||||
|
||||
if(USE_SYSTEM_CURL)
|
||||
find_package(CURL REQUIRED)
|
||||
target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE "${CURL_LIBRARIES}")
|
||||
target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${CURL_INCLUDE_DIRS}")
|
||||
else()
|
||||
include(cmake/BuildMyCurl.cmake)
|
||||
target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE libcurl)
|
||||
endif()
|
||||
|
||||
include(cmake/BuildWhispercpp.cmake)
|
||||
target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Whispercpp)
|
||||
|
||||
target_sources(
|
||||
${CMAKE_PROJECT_NAME}
|
||||
PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c src/whisper-processing.cpp
|
||||
src/model-utils/model-downloader.cpp src/model-utils/model-downloader-ui.cpp)
|
||||
|
||||
set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})
|
||||
|
||||
@@ -23,8 +23,8 @@
|
||||
"CMAKE_OSX_DEPLOYMENT_TARGET": "11.0",
|
||||
"CODESIGN_IDENTITY": "$penv{CODESIGN_IDENT}",
|
||||
"CODESIGN_TEAM": "$penv{CODESIGN_TEAM}",
|
||||
"ENABLE_FRONTEND_API": false,
|
||||
"ENABLE_QT": false
|
||||
"ENABLE_FRONTEND_API": true,
|
||||
"ENABLE_QT": true
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -53,8 +53,8 @@
|
||||
"cacheVariables": {
|
||||
"QT_VERSION": "6",
|
||||
"CMAKE_SYSTEM_VERSION": "10.0.18363.657",
|
||||
"ENABLE_FRONTEND_API": false,
|
||||
"ENABLE_QT": false
|
||||
"ENABLE_FRONTEND_API": true,
|
||||
"ENABLE_QT": true
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -81,8 +81,8 @@
|
||||
"cacheVariables": {
|
||||
"QT_VERSION": "6",
|
||||
"CMAKE_BUILD_TYPE": "RelWithDebInfo",
|
||||
"ENABLE_FRONTEND_API": false,
|
||||
"ENABLE_QT": false
|
||||
"ENABLE_FRONTEND_API": true,
|
||||
"ENABLE_QT": true
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -110,8 +110,8 @@
|
||||
"cacheVariables": {
|
||||
"QT_VERSION": "6",
|
||||
"CMAKE_BUILD_TYPE": "RelWithDebInfo",
|
||||
"ENABLE_FRONTEND_API": false,
|
||||
"ENABLE_QT": false
|
||||
"ENABLE_FRONTEND_API": true,
|
||||
"ENABLE_QT": true
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -41,14 +41,14 @@
|
||||
},
|
||||
"platformConfig": {
|
||||
"macos": {
|
||||
"bundleId": "com.example.obs-plugintemplate"
|
||||
"bundleId": "com.royshilkrot.obs-localvocal"
|
||||
}
|
||||
},
|
||||
"name": "obs-plugintemplate",
|
||||
"version": "1.0.0",
|
||||
"author": "Your Name Here",
|
||||
"website": "https://example.com",
|
||||
"email": "me@example.com",
|
||||
"name": "obs-localvocal",
|
||||
"version": "0.0.1",
|
||||
"author": "Roy Shilkrot",
|
||||
"website": "https://github.com/royshil/obs-localvocal",
|
||||
"email": "roy.shil@gmail.com",
|
||||
"uuids": {
|
||||
"macosPackage": "00000000-0000-0000-0000-000000000000",
|
||||
"macosInstaller": "00000000-0000-0000-0000-000000000000",
|
||||
|
||||
29
cmake/BuildMyCurl.cmake
Normal file
29
cmake/BuildMyCurl.cmake
Normal file
@@ -0,0 +1,29 @@
|
||||
# Builds the vendored libcurl (git submodule in vendor/curl) as a static library
# and makes its headers available to the plugin.
set(LIBCURL_SOURCE_DIR ${CMAKE_SOURCE_DIR}/vendor/curl)

# Pin the submodule checkout to the curl-8_2_0 tag. A failure here is reported
# instead of being silently ignored, because the build would otherwise continue
# with whatever revision happens to be checked out.
find_package(Git QUIET)
if(GIT_FOUND)
  execute_process(
    COMMAND ${GIT_EXECUTABLE} checkout curl-8_2_0
    WORKING_DIRECTORY ${LIBCURL_SOURCE_DIR}
    RESULT_VARIABLE GIT_SUBMOD_RESULT)
  if(NOT GIT_SUBMOD_RESULT STREQUAL "0")
    message(WARNING "git checkout curl-8_2_0 failed in ${LIBCURL_SOURCE_DIR}: ${GIT_SUBMOD_RESULT}")
  endif()
else()
  message(WARNING "Git not found: cannot check out curl-8_2_0 in ${LIBCURL_SOURCE_DIR}")
endif()

# Select the TLS backend per platform.
if(OS_MACOS)
  set(CURL_USE_OPENSSL OFF)
  set(CURL_USE_SECTRANSP ON)
elseif(OS_WINDOWS)
  set(CURL_USE_OPENSSL OFF)
  set(CURL_USE_SCHANNEL ON)
elseif(OS_LINUX)
  add_compile_options(-fPIC)
  set(CURL_USE_OPENSSL ON)
endif()
set(BUILD_CURL_EXE OFF)
set(BUILD_SHARED_LIBS OFF)
set(HTTP_ONLY OFF)
set(CURL_USE_LIBSSH2 OFF)
add_subdirectory(${LIBCURL_SOURCE_DIR} EXCLUDE_FROM_ALL)
if(OS_MACOS)
  # Keep these diagnostics from being promoted to errors when compiling libcurl.
  target_compile_options(
    libcurl PRIVATE -Wno-error=ambiguous-macro -Wno-error=deprecated-declarations -Wno-error=unreachable-code
                    -Wno-error=unused-parameter -Wno-error=unused-variable)
endif()
include_directories(SYSTEM ${LIBCURL_SOURCE_DIR}/include)
|
||||
51
cmake/BuildWhispercpp.cmake
Normal file
51
cmake/BuildWhispercpp.cmake
Normal file
@@ -0,0 +1,51 @@
|
||||
# Downloads and builds whisper.cpp as a static library through ExternalProject and
# exposes it to the plugin as the `Whispercpp` interface target.
include(ExternalProject)

set(CMAKE_OSX_ARCHITECTURES_ "arm64$<SEMICOLON>x86_64")

# Compare the variable by name (not by ${} expansion): with an empty
# CMAKE_BUILD_TYPE (e.g. multi-config generators) the expanded form turns into a
# malformed `if( STREQUAL Release ...)`. An empty build type selects Debug.
if(CMAKE_BUILD_TYPE STREQUAL "Release" OR CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
  set(Whispercpp_BUILD_TYPE Release)
else()
  set(Whispercpp_BUILD_TYPE Debug)
endif()

# On linux add the `-fPIC` flag to the compiler
if(UNIX AND NOT APPLE)
  set(WHISPER_EXTRA_CXX_FLAGS "-fPIC")
endif()

ExternalProject_Add(
  Whispercpp_Build
  DOWNLOAD_EXTRACT_TIMESTAMP true
  GIT_REPOSITORY https://github.com/ggerganov/whisper.cpp.git
  GIT_TAG 7b374c9ac9b9861bb737eec060e4dfa29d229259
  BUILD_COMMAND ${CMAKE_COMMAND} --build <BINARY_DIR> --config ${Whispercpp_BUILD_TYPE}
  BUILD_BYPRODUCTS <INSTALL_DIR>/lib/static/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX}
  CMAKE_GENERATOR ${CMAKE_GENERATOR}
  INSTALL_COMMAND ${CMAKE_COMMAND} --install <BINARY_DIR> --config ${Whispercpp_BUILD_TYPE}
  CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
             -DCMAKE_BUILD_TYPE=${Whispercpp_BUILD_TYPE}
             -DCMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM}
             -DCMAKE_OSX_DEPLOYMENT_TARGET=10.13
             -DCMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES_}
             -DCMAKE_CXX_FLAGS=${WHISPER_EXTRA_CXX_FLAGS}
             -DCMAKE_C_FLAGS=${WHISPER_EXTRA_CXX_FLAGS}
             -DBUILD_SHARED_LIBS=OFF
             -DWHISPER_BUILD_TESTS=OFF
             -DWHISPER_BUILD_EXAMPLES=OFF
             -DWHISPER_OPENBLAS=ON)

ExternalProject_Get_Property(Whispercpp_Build INSTALL_DIR)

# Imported target pointing at the static library installed by the external build.
add_library(Whispercpp::Whisper STATIC IMPORTED)
set_target_properties(
  Whispercpp::Whisper
  PROPERTIES IMPORTED_LOCATION
             ${INSTALL_DIR}/lib/static/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX})

# Interface target the plugin links against; it depends on the external build step.
add_library(Whispercpp INTERFACE)
add_dependencies(Whispercpp Whispercpp_Build)
target_link_libraries(Whispercpp INTERFACE Whispercpp::Whisper)
set_target_properties(Whispercpp::Whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${INSTALL_DIR}/include)
if(APPLE)
  target_link_libraries(Whispercpp INTERFACE "-framework Accelerate")
endif()
|
||||
@@ -73,6 +73,13 @@ function(_setup_obs_studio)
|
||||
set(_cmake_version "3.0.0")
|
||||
endif()
|
||||
|
||||
message(STATUS "Patch libobs")
|
||||
execute_process(
|
||||
COMMAND patch --forward "libobs/CMakeLists.txt" "${CMAKE_CURRENT_SOURCE_DIR}/patch_libobs.diff"
|
||||
RESULT_VARIABLE _process_result
|
||||
WORKING_DIRECTORY "${dependencies_dir}/${_obs_destination}")
|
||||
message(STATUS "Patch - done")
|
||||
|
||||
message(STATUS "Configure ${label} (${arch})")
|
||||
execute_process(
|
||||
COMMAND
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
transcription_filterAudioFilter=LocalVocal Transcription
|
||||
|
||||
BIN
data/models/ggml-tiny.en.bin
Normal file
BIN
data/models/ggml-tiny.en.bin
Normal file
Binary file not shown.
20
patch_libobs.diff
Normal file
20
patch_libobs.diff
Normal file
@@ -0,0 +1,20 @@
|
||||
diff --git a/libobs/CMakeLists.txt b/libobs/CMakeLists.txt
|
||||
index d2e2671..5a9242a 100644
|
||||
--- a/libobs/CMakeLists.txt
|
||||
+++ b/libobs/CMakeLists.txt
|
||||
@@ -263,6 +263,7 @@ set(public_headers
|
||||
graphics/vec3.h
|
||||
graphics/vec4.h
|
||||
media-io/audio-io.h
|
||||
+ media-io/audio-resampler.h
|
||||
media-io/frame-rate.h
|
||||
media-io/media-io-defs.h
|
||||
media-io/video-io.h
|
||||
@@ -287,6 +288,7 @@ set(public_headers
|
||||
util/base.h
|
||||
util/bmem.h
|
||||
util/c99defs.h
|
||||
+ util/circlebuf.h
|
||||
util/darray.h
|
||||
util/profiler.h
|
||||
util/sse-intrin.h
|
||||
180
src/model-utils/model-downloader-ui.cpp
Normal file
180
src/model-utils/model-downloader-ui.cpp
Normal file
@@ -0,0 +1,180 @@
|
||||
#include "model-downloader-ui.h"
|
||||
#include "plugin-support.h"
|
||||
|
||||
#include <obs-module.h>
|
||||
|
||||
const std::string MODEL_BASE_PATH = "https://huggingface.co/ggerganov/whisper.cpp";
|
||||
const std::string MODEL_PREFIX = "resolve/main/";
|
||||
|
||||
// Write callback handed to libcurl (CURLOPT_WRITEFUNCTION): stores the received
// chunk in the open file `stream` and returns whatever fwrite() reports.
size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream)
{
	return fwrite(ptr, size, nmemb, stream);
}
|
||||
|
||||
// Builds the "Downloading model..." dialog (model name label + progress bar) and
// immediately starts the download on a background QThread.
//
// model_name: shown in the dialog and handed to the ModelDownloadWorker.
// download_finished_callback_: stored; invoked with 0 from download_finished()
//     and with 1 from show_error().
// parent: parent widget of the dialog.
ModelDownloader::ModelDownloader(
	const std::string &model_name,
	std::function<void(int download_status)> download_finished_callback_, QWidget *parent)
	: QDialog(parent), download_finished_callback(download_finished_callback_)
{
	this->setWindowTitle("Downloading model...");
	this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint);
	this->setFixedSize(300, 100);

	this->layout = new QVBoxLayout(this);

	// Add a label for the model name
	QLabel *model_name_label = new QLabel(this);
	model_name_label->setText(QString::fromStdString(model_name));
	model_name_label->setAlignment(Qt::AlignCenter);
	this->layout->addWidget(model_name_label);

	this->progress_bar = new QProgressBar(this);
	this->progress_bar->setRange(0, 100);
	this->progress_bar->setValue(0);
	this->progress_bar->setAlignment(Qt::AlignCenter);
	// Show progress as a percentage
	this->progress_bar->setFormat("%p%");
	this->layout->addWidget(this->progress_bar);

	// The worker has no parent so that it can be moved to the download thread.
	this->download_thread = new QThread();
	this->download_worker = new ModelDownloadWorker(model_name);
	this->download_worker->moveToThread(this->download_thread);

	// Starting the thread kicks off the download in the worker.
	connect(this->download_thread, &QThread::started, this->download_worker,
		&ModelDownloadWorker::download_model);
	connect(this->download_worker, &ModelDownloadWorker::download_progress, this,
		&ModelDownloader::update_progress);
	connect(this->download_worker, &ModelDownloadWorker::download_finished, this,
		&ModelDownloader::download_finished);
	// Stop the thread's event loop once the worker reports completion.
	connect(this->download_worker, &ModelDownloadWorker::download_finished,
		this->download_thread, &QThread::quit);
	connect(this->download_worker, &ModelDownloadWorker::download_error, this,
		&ModelDownloader::show_error);

	// Ownership: the thread and the worker are deleted by ~ModelDownloader().
	// They are deliberately NOT scheduled with deleteLater() as well; doing both
	// made the destructor call quit()/wait()/delete on already-freed objects.

	this->download_thread->start();
}
|
||||
|
||||
void ModelDownloader::update_progress(int progress)
|
||||
{
|
||||
this->progress_bar->setValue(progress);
|
||||
}
|
||||
|
||||
void ModelDownloader::download_finished()
|
||||
{
|
||||
this->setWindowTitle("Download finished!");
|
||||
this->progress_bar->setValue(100);
|
||||
this->progress_bar->setFormat("Download finished!");
|
||||
this->progress_bar->setAlignment(Qt::AlignCenter);
|
||||
this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #05B8CC; }");
|
||||
// Add a button to close the dialog
|
||||
QPushButton *close_button = new QPushButton("Close", this);
|
||||
this->layout->addWidget(close_button);
|
||||
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
|
||||
// Call the callback
|
||||
this->download_finished_callback(0);
|
||||
}
|
||||
|
||||
void ModelDownloader::show_error(const std::string &reason)
|
||||
{
|
||||
this->setWindowTitle("Download failed!");
|
||||
this->progress_bar->setFormat("Download failed!");
|
||||
this->progress_bar->setAlignment(Qt::AlignCenter);
|
||||
this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #FF0000; }");
|
||||
// Add a label to show the error
|
||||
QLabel *error_label = new QLabel(this);
|
||||
error_label->setText(QString::fromStdString(reason));
|
||||
error_label->setAlignment(Qt::AlignCenter);
|
||||
// Color red
|
||||
error_label->setStyleSheet("QLabel { color : red; }");
|
||||
this->layout->addWidget(error_label);
|
||||
// Add a button to close the dialog
|
||||
QPushButton *close_button = new QPushButton("Close", this);
|
||||
this->layout->addWidget(close_button);
|
||||
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
|
||||
this->download_finished_callback(1);
|
||||
}
|
||||
|
||||
// Remembers which model download_model() should fetch.
ModelDownloadWorker::ModelDownloadWorker(const std::string &model_name_)
	: model_name(model_name_)
{
}
|
||||
|
||||
void ModelDownloadWorker::download_model()
|
||||
{
|
||||
std::string module_data_dir = obs_get_module_data_path(obs_current_module());
|
||||
// join the directory and the filename using the platform-specific separator
|
||||
std::string model_save_path = module_data_dir + "/" + this->model_name;
|
||||
obs_log(LOG_INFO, "Model save path: %s", model_save_path.c_str());
|
||||
|
||||
// extract filename from path in this->modle_name
|
||||
const std::string model_filename =
|
||||
this->model_name.substr(this->model_name.find_last_of("/\\") + 1);
|
||||
|
||||
std::string model_url = MODEL_BASE_PATH + "/" + MODEL_PREFIX + model_filename;
|
||||
obs_log(LOG_INFO, "Model URL: %s", model_url.c_str());
|
||||
|
||||
CURL *curl = curl_easy_init();
|
||||
if (curl) {
|
||||
FILE *fp = fopen(model_save_path.c_str(), "wb");
|
||||
if (fp == nullptr) {
|
||||
obs_log(LOG_ERROR, "Failed to open file %s.", model_save_path.c_str());
|
||||
emit download_error("Failed to open file.");
|
||||
return;
|
||||
}
|
||||
curl_easy_setopt(curl, CURLOPT_URL, model_url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data);
|
||||
curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp);
|
||||
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
|
||||
curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION,
|
||||
ModelDownloadWorker::progress_callback);
|
||||
curl_easy_setopt(curl, CURLOPT_XFERINFODATA, this);
|
||||
// Follow redirects
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
CURLcode res = curl_easy_perform(curl);
|
||||
if (res != CURLE_OK) {
|
||||
obs_log(LOG_ERROR, "Failed to download model %s.",
|
||||
this->model_name.c_str());
|
||||
emit download_error("Failed to download model.");
|
||||
}
|
||||
curl_easy_cleanup(curl);
|
||||
fclose(fp);
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "Failed to initialize curl.");
|
||||
emit download_error("Failed to initialize curl.");
|
||||
}
|
||||
emit download_finished();
|
||||
}
|
||||
|
||||
int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
|
||||
curl_off_t, curl_off_t)
|
||||
{
|
||||
if (dltotal == 0) {
|
||||
return 0; // Unknown progress
|
||||
}
|
||||
ModelDownloadWorker *worker = (ModelDownloadWorker *)clientp;
|
||||
if (worker == nullptr) {
|
||||
obs_log(LOG_ERROR, "Worker is null.");
|
||||
return 1;
|
||||
}
|
||||
int progress = (int)(dlnow * 100l / dltotal);
|
||||
emit worker->download_progress(progress);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Stops the download thread and frees the thread and worker objects.
ModelDownloader::~ModelDownloader()
{
	// Ask the thread's event loop to exit, then block until the thread has ended.
	// NOTE(review): if download_model() is still inside the transfer, wait()
	// presumably blocks until it returns — confirm this is acceptable on the UI thread.
	this->download_thread->quit();
	this->download_thread->wait();
	// NOTE(review): these raw deletes are only safe if nothing else has already
	// deleted the thread/worker (e.g. through deleteLater()) — confirm ownership.
	delete this->download_thread;
	delete this->download_worker;
}
|
||||
|
||||
// The worker has nothing to release explicitly.
ModelDownloadWorker::~ModelDownloadWorker() = default;
|
||||
54
src/model-utils/model-downloader-ui.h
Normal file
54
src/model-utils/model-downloader-ui.h
Normal file
@@ -0,0 +1,54 @@
|
||||
#ifndef MODEL_DOWNLOADER_UI_H
|
||||
#define MODEL_DOWNLOADER_UI_H
|
||||
|
||||
#include <QtWidgets>
|
||||
#include <QThread>
|
||||
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
#include <curl/curl.h>
|
||||
|
||||
// QObject that performs the model download with libcurl. ModelDownloader moves
// it to a QThread and triggers download_model() when that thread starts.
class ModelDownloadWorker : public QObject {
	Q_OBJECT
public:
	// model_name: model to download; stored for download_model().
	ModelDownloadWorker(const std::string &model_name);
	~ModelDownloadWorker();

public slots:
	// Downloads the model file and reports the outcome through the signals below.
	void download_model();

signals:
	// Download progress as an integer percentage.
	void download_progress(int progress);
	// Emitted when download_model() is done.
	void download_finished();
	// Emitted with a short human-readable reason when something failed.
	void download_error(const std::string &reason);

private:
	// libcurl transfer-info callback; clientp is expected to be the worker instance.
	static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
				     curl_off_t ultotal, curl_off_t ulnow);
	std::string model_name;
};
|
||||
|
||||
// Dialog that shows the progress of a model download running on a background
// thread and reports the result through a callback.
class ModelDownloader : public QDialog {
	Q_OBJECT
public:
	// model_name: model to download (also displayed in the dialog).
	// download_finished_callback: called with 0 on success and 1 on failure.
	// parent: parent widget of the dialog.
	ModelDownloader(const std::string &model_name,
			std::function<void(int download_status)> download_finished_callback,
			QWidget *parent = nullptr);
	~ModelDownloader();

public slots:
	// Updates the progress bar with the given percentage.
	void update_progress(int progress);
	// Switches the dialog to the "finished" state and invokes the callback with 0.
	void download_finished();
	// Switches the dialog to the "failed" state, shows `reason` and invokes the callback with 1.
	void show_error(const std::string &reason);

private:
	QVBoxLayout *layout;
	QProgressBar *progress_bar;
	// Thread on which download_worker runs.
	QThread *download_thread;
	ModelDownloadWorker *download_worker;
	// Callback for when the download is finished
	std::function<void(int download_status)> download_finished_callback;
};
|
||||
|
||||
#endif // MODEL_DOWNLOADER_UI_H
|
||||
42
src/model-utils/model-downloader.cpp
Normal file
42
src/model-utils/model-downloader.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
#include "model-downloader.h"
|
||||
#include "plugin-support.h"
|
||||
#include "model-downloader-ui.h"
|
||||
|
||||
#include <obs-module.h>
|
||||
#include <obs-frontend-api.h>
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include <curl/curl.h>
|
||||
|
||||
bool check_if_model_exists(const std::string &model_name)
|
||||
{
|
||||
obs_log(LOG_INFO, "Checking if model %s exists...", model_name.c_str());
|
||||
char *model_file_path = obs_module_file(model_name.c_str());
|
||||
obs_log(LOG_INFO, "Model file path: %s", model_file_path);
|
||||
if (model_file_path == nullptr) {
|
||||
obs_log(LOG_INFO, "Model %s does not exist.", model_name.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!std::filesystem::exists(model_file_path)) {
|
||||
obs_log(LOG_INFO, "Model %s does not exist.", model_file_path);
|
||||
bfree(model_file_path);
|
||||
return false;
|
||||
}
|
||||
bfree(model_file_path);
|
||||
return true;
|
||||
}
|
||||
|
||||
void download_model_with_ui_dialog(
|
||||
const std::string &model_name,
|
||||
std::function<void(int download_status)> download_finished_callback)
|
||||
{
|
||||
// Start the model downloader UI
|
||||
ModelDownloader *model_downloader = new ModelDownloader(
|
||||
model_name, download_finished_callback, (QWidget *)obs_frontend_get_main_window());
|
||||
model_downloader->show();
|
||||
}
|
||||
14
src/model-utils/model-downloader.h
Normal file
14
src/model-utils/model-downloader.h
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef MODEL_DOWNLOADER_H
|
||||
#define MODEL_DOWNLOADER_H
|
||||
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
// Returns true when the given model file is found among the module's files.
bool check_if_model_exists(const std::string &model_name);

// Start the model downloader UI dialog with a callback for when the download is finished.
// The callback receives 0 when the download succeeded and 1 when it failed.
void download_model_with_ui_dialog(
	const std::string &model_name,
	std::function<void(int download_status)> download_finished_callback);
|
||||
|
||||
#endif // MODEL_DOWNLOADER_H
|
||||
@@ -22,10 +22,17 @@ with this program. If not, see <https://www.gnu.org/licenses/>
|
||||
OBS_DECLARE_MODULE()
|
||||
OBS_MODULE_USE_DEFAULT_LOCALE(PLUGIN_NAME, "en-US")
|
||||
|
||||
/* Module description: looks up the "LocalVocalPlugin" text key via obs_module_text(). */
MODULE_EXPORT const char *obs_module_description(void)
{
	const char *description = obs_module_text("LocalVocalPlugin");
	return description;
}
|
||||
|
||||
extern struct obs_source_info transcription_filter_info;
|
||||
|
||||
bool obs_module_load(void)
|
||||
{
|
||||
obs_log(LOG_INFO, "plugin loaded successfully (version %s)",
|
||||
PLUGIN_VERSION);
|
||||
obs_register_source(&transcription_filter_info);
|
||||
blog(LOG_INFO, "plugin loaded successfully (version %s)", PLUGIN_VERSION);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,8 @@ with this program. If not, see <https://www.gnu.org/licenses/>
|
||||
const char *PLUGIN_NAME = "@CMAKE_PROJECT_NAME@";
|
||||
const char *PLUGIN_VERSION = "@CMAKE_PROJECT_VERSION@";
|
||||
|
||||
extern void blogva(int log_level, const char *format, va_list args);
|
||||
|
||||
void obs_log(int log_level, const char *format, ...)
|
||||
{
|
||||
size_t length = 4 + strlen(PLUGIN_NAME) + strlen(format);
|
||||
|
||||
@@ -31,7 +31,6 @@ extern const char *PLUGIN_NAME;
|
||||
extern const char *PLUGIN_VERSION;
|
||||
|
||||
void obs_log(int log_level, const char *format, ...);
|
||||
extern void blogva(int log_level, const char *format, va_list args);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
76
src/transcription-filter-data.h
Normal file
76
src/transcription-filter-data.h
Normal file
@@ -0,0 +1,76 @@
|
||||
#ifndef TRANSCRIPTION_FILTER_DATA_H
|
||||
#define TRANSCRIPTION_FILTER_DATA_H
|
||||
|
||||
#include <obs.h>
|
||||
#include <util/circlebuf.h>
|
||||
#include <util/darray.h>
|
||||
#include <media-io/audio-resampler.h>
|
||||
|
||||
#include <whisper.h>
|
||||
|
||||
#include <thread>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#define MAX_PREPROC_CHANNELS 2
|
||||
|
||||
#define MT_ obs_module_text
|
||||
|
||||
// Per-filter-instance state shared between the OBS filter callbacks and the
// whisper processing thread.
// NOTE(review): this struct has non-trivial members (std::string, std::thread,
// std::unique_ptr, std::function) while the destroy callback releases it with
// bfree() — confirm how instances are constructed and destructed.
struct transcription_filter_data {
	obs_source_t *context; // obs input source
	size_t channels;       // number of channels
	uint32_t sample_rate;  // input sample rate
	// How many input frames (in input sample rate) are needed for the next whisper frame
	size_t frames;
	// How many ms/frames are needed to overlap with the next whisper frame
	size_t overlap_frames;
	size_t overlap_ms;
	// How many frames were processed in the last whisper frame (this is dynamic)
	size_t last_num_frames;

	/* PCM buffers */
	float *copy_buffers[MAX_PREPROC_CHANNELS];
	// Queue of transcription_filter_audio_info entries, one per pushed audio packet
	struct circlebuf info_buffer;
	// Per-channel queues of incoming float samples
	struct circlebuf input_buffers[MAX_PREPROC_CHANNELS];

	/* Resampler */
	audio_resampler_t *resampler;

	/* whisper */
	std::string whisper_model_path = "models/ggml-tiny.en.bin";
	// nullptr while whisper is not initialized; the audio callback then passes audio through
	struct whisper_context *whisper_context;
	whisper_full_params whisper_params;

	float filler_p_threshold;

	bool do_silence;
	bool vad_enabled;
	// Log level handed to obs_log() for this filter's diagnostic messages
	int log_level;
	bool log_words;
	// When false the audio callback returns without buffering anything
	bool active;

	// Text source to output the subtitles
	obs_weak_source_t *text_source;
	char *text_source_name;
	// Held while text_source is read or replaced
	std::unique_ptr<std::mutex> text_source_mutex;
	// Callback to set the text in the output text source (subtitles)
	std::function<void(const std::string &str)> setTextCallback;

	// Use std for thread and mutex
	std::thread whisper_thread;

	// Held while input_buffers / info_buffer are modified
	std::unique_ptr<std::mutex> whisper_buf_mutex;
	// Held while whisper_context is freed
	std::unique_ptr<std::mutex> whisper_ctx_mutex;
	// Notified by the destroy callback after whisper_context is freed
	std::unique_ptr<std::condition_variable> wshiper_thread_cv;
};
|
||||
|
||||
// Audio packet info: one entry is queued in info_buffer for every audio packet
// pushed into the input buffers.
struct transcription_filter_audio_info {
	uint32_t frames;    // number of frames in the packet
	uint64_t timestamp; // timestamp of the packet
};
|
||||
|
||||
#endif /* TRANSCRIPTION_FILTER_DATA_H */
|
||||
16
src/transcription-filter.c
Normal file
16
src/transcription-filter.c
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "transcription-filter.h"
|
||||
|
||||
/* Source registration record for the transcription audio filter; plugin-main
 * passes it to obs_register_source(). Each member points at the matching
 * transcription_filter_* callback. */
struct obs_source_info transcription_filter_info = {
	.id = "transcription_filter_audio_filter",
	.type = OBS_SOURCE_TYPE_FILTER,
	.output_flags = OBS_SOURCE_AUDIO,
	.get_name = transcription_filter_name,
	.create = transcription_filter_create,
	.destroy = transcription_filter_destroy,
	.get_defaults = transcription_filter_defaults,
	.get_properties = transcription_filter_properties,
	.update = transcription_filter_update,
	.activate = transcription_filter_activate,
	.deactivate = transcription_filter_deactivate,
	.filter_audio = transcription_filter_filter_audio,
};
|
||||
536
src/transcription-filter.cpp
Normal file
536
src/transcription-filter.cpp
Normal file
@@ -0,0 +1,536 @@
|
||||
#include <obs-module.h>
|
||||
|
||||
#include "plugin-support.h"
|
||||
#include "transcription-filter.h"
|
||||
#include "transcription-filter-data.h"
|
||||
#include "whisper-processing.h"
|
||||
#include "whisper-language.h"
|
||||
#include "model-utils/model-downloader.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
inline enum speaker_layout convert_speaker_layout(uint8_t channels)
|
||||
{
|
||||
switch (channels) {
|
||||
case 0:
|
||||
return SPEAKERS_UNKNOWN;
|
||||
case 1:
|
||||
return SPEAKERS_MONO;
|
||||
case 2:
|
||||
return SPEAKERS_STEREO;
|
||||
case 3:
|
||||
return SPEAKERS_2POINT1;
|
||||
case 4:
|
||||
return SPEAKERS_4POINT0;
|
||||
case 5:
|
||||
return SPEAKERS_4POINT1;
|
||||
case 6:
|
||||
return SPEAKERS_5POINT1;
|
||||
case 8:
|
||||
return SPEAKERS_7POINT1;
|
||||
default:
|
||||
return SPEAKERS_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
bool add_sources_to_list(void *list_property, obs_source_t *source)
|
||||
{
|
||||
auto source_id = obs_source_get_id(source);
|
||||
if (strcmp(source_id, "text_ft2_source_v2") != 0 &&
|
||||
strcmp(source_id, "text_gdiplus_v2") != 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
obs_property_t *sources = (obs_property_t *)list_property;
|
||||
const char *name = obs_source_get_name(source);
|
||||
obs_property_list_add_string(sources, name, name);
|
||||
return true;
|
||||
}
|
||||
|
||||
struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_audio_data *audio)
|
||||
{
|
||||
if (!audio) {
|
||||
return nullptr;
|
||||
}
|
||||
if (data == nullptr) {
|
||||
return audio;
|
||||
}
|
||||
|
||||
struct transcription_filter_data *gf =
|
||||
static_cast<struct transcription_filter_data *>(data);
|
||||
|
||||
if (!gf->active) {
|
||||
return audio;
|
||||
}
|
||||
|
||||
if (gf->whisper_context == nullptr) {
|
||||
// Whisper not initialized, just pass through
|
||||
return audio;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex); // scoped lock
|
||||
obs_log(gf->log_level,
|
||||
"pushing %lu frames to input buffer. current size: %lu (bytes)",
|
||||
(size_t)(audio->frames), gf->input_buffers[0].size);
|
||||
// push back current audio data to input circlebuf
|
||||
for (size_t c = 0; c < gf->channels; c++) {
|
||||
circlebuf_push_back(&gf->input_buffers[c], audio->data[c],
|
||||
audio->frames * sizeof(float));
|
||||
}
|
||||
// push audio packet info (timestamp/frame count) to info circlebuf
|
||||
struct transcription_filter_audio_info info = {0};
|
||||
info.frames = audio->frames; // number of frames in this packet
|
||||
info.timestamp = audio->timestamp; // timestamp of this packet
|
||||
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
|
||||
}
|
||||
|
||||
return audio;
|
||||
}
|
||||
|
||||
const char *transcription_filter_name(void *unused)
|
||||
{
|
||||
UNUSED_PARAMETER(unused);
|
||||
return MT_("transcription_filterAudioFilter");
|
||||
}
|
||||
|
||||
// Destroy callback: tears down the whisper context and thread, releases the
// text source reference, the resampler and the audio buffers, then frees the
// filter data itself.
void transcription_filter_destroy(void *data)
{
	struct transcription_filter_data *gf =
		static_cast<struct transcription_filter_data *>(data);

	obs_log(LOG_INFO, "transcription_filter_destroy");
	{
		// Free the whisper context under its mutex and wake any waiter on the
		// condition variable.
		std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
		if (gf->whisper_context != nullptr) {
			whisper_free(gf->whisper_context);
			gf->whisper_context = nullptr;
			gf->wshiper_thread_cv->notify_all();
		}
	}

	// join the thread
	// (done before the buffers below are freed)
	if (gf->whisper_thread.joinable()) {
		gf->whisper_thread.join();
	}

	if (gf->text_source_name) {
		bfree(gf->text_source_name);
		gf->text_source_name = nullptr;
	}

	if (gf->text_source) {
		obs_weak_source_release(gf->text_source);
		gf->text_source = nullptr;
	}

	if (gf->resampler) {
		audio_resampler_destroy(gf->resampler);
	}

	{
		std::lock_guard<std::mutex> lockbuf(*gf->whisper_buf_mutex);
		// Only copy_buffers[0] is freed here.
		// NOTE(review): presumably all channels share one allocation owned by
		// index 0 — confirm against the allocation site.
		bfree(gf->copy_buffers[0]);
		gf->copy_buffers[0] = nullptr;
		for (size_t i = 0; i < gf->channels; i++) {
			circlebuf_free(&gf->input_buffers[i]);
		}
	}
	circlebuf_free(&gf->info_buffer);

	// NOTE(review): bfree() releases the memory without running the destructors
	// of the C++ members (std::string, unique_ptr, std::function, std::thread) —
	// confirm this matches how the struct was created.
	bfree(gf);
}
|
||||
|
||||
void acquire_weak_text_source_ref(struct transcription_filter_data *gf)
|
||||
{
|
||||
if (!gf->text_source_name) {
|
||||
obs_log(LOG_ERROR, "text_source_name is null");
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
|
||||
|
||||
// acquire a weak ref to the new text source
|
||||
obs_source_t *source = obs_get_source_by_name(gf->text_source_name);
|
||||
if (source) {
|
||||
gf->text_source = obs_source_get_weak_source(source);
|
||||
obs_source_release(source);
|
||||
if (!gf->text_source) {
|
||||
obs_log(LOG_ERROR, "failed to get weak source for text source %s",
|
||||
gf->text_source_name);
|
||||
}
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "text source '%s' not found", gf->text_source_name);
|
||||
}
|
||||
}
|
||||
|
||||
void transcription_filter_update(void *data, obs_data_t *s)
|
||||
{
|
||||
struct transcription_filter_data *gf =
|
||||
static_cast<struct transcription_filter_data *>(data);
|
||||
|
||||
gf->log_level = (int)obs_data_get_int(s, "log_level");
|
||||
gf->vad_enabled = obs_data_get_bool(s, "vad_enabled");
|
||||
gf->log_words = obs_data_get_bool(s, "log_words");
|
||||
|
||||
// update the text source
|
||||
const char *text_source_name = obs_data_get_string(s, "subtitle_sources");
|
||||
obs_weak_source_t *old_weak_text_source = NULL;
|
||||
|
||||
if (strcmp(text_source_name, "none") == 0 || strcmp(text_source_name, "(null)") == 0) {
|
||||
// new selected text source is not valid, release the old one
|
||||
if (gf->text_source) {
|
||||
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
|
||||
old_weak_text_source = gf->text_source;
|
||||
gf->text_source = nullptr;
|
||||
}
|
||||
if (gf->text_source_name) {
|
||||
bfree(gf->text_source_name);
|
||||
gf->text_source_name = nullptr;
|
||||
}
|
||||
} else {
|
||||
// new selected text source is valid, check if it's different from the old one
|
||||
if (gf->text_source_name == nullptr ||
|
||||
strcmp(text_source_name, gf->text_source_name) != 0) {
|
||||
// new text source is different from the old one, release the old one
|
||||
if (gf->text_source) {
|
||||
std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
|
||||
old_weak_text_source = gf->text_source;
|
||||
gf->text_source = nullptr;
|
||||
}
|
||||
gf->text_source_name = bstrdup(text_source_name);
|
||||
}
|
||||
}
|
||||
|
||||
if (old_weak_text_source) {
|
||||
obs_weak_source_release(old_weak_text_source);
|
||||
}
|
||||
|
||||
const char *new_model_path = obs_data_get_string(s, "whisper_model_path");
|
||||
if (strcmp(new_model_path, gf->whisper_model_path.c_str()) != 0) {
|
||||
// model path changed, reload the model
|
||||
obs_log(LOG_INFO, "model path changed, reloading model");
|
||||
if (gf->whisper_context != nullptr) {
|
||||
// acquire the mutex before freeing the context
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
whisper_free(gf->whisper_context);
|
||||
gf->whisper_context = nullptr;
|
||||
gf->wshiper_thread_cv->notify_all();
|
||||
}
|
||||
if (gf->whisper_thread.joinable()) {
|
||||
gf->whisper_thread.join();
|
||||
}
|
||||
gf->whisper_model_path = bstrdup(new_model_path);
|
||||
|
||||
// check if the model exists, if not, download it
|
||||
if (!check_if_model_exists(gf->whisper_model_path)) {
|
||||
obs_log(LOG_ERROR, "Whisper model does not exist");
|
||||
download_model_with_ui_dialog(
|
||||
gf->whisper_model_path, [gf](int download_status) {
|
||||
if (download_status == 0) {
|
||||
obs_log(LOG_INFO, "Model download complete");
|
||||
gf->whisper_context = init_whisper_context(
|
||||
gf->whisper_model_path);
|
||||
gf->whisper_thread = std::thread(whisper_loop, gf);
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "Model download failed");
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Model exists, just load it
|
||||
gf->whisper_context = init_whisper_context(gf->whisper_model_path);
|
||||
gf->whisper_thread = std::thread(whisper_loop, gf);
|
||||
}
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
|
||||
gf->whisper_params = whisper_full_default_params(
|
||||
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
|
||||
gf->whisper_params.duration_ms = BUFFER_SIZE_MSEC;
|
||||
gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select");
|
||||
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
|
||||
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
|
||||
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
|
||||
gf->whisper_params.translate = obs_data_get_bool(s, "translate");
|
||||
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
|
||||
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
|
||||
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
|
||||
gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress");
|
||||
gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime");
|
||||
gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps");
|
||||
gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps");
|
||||
gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt");
|
||||
gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum");
|
||||
gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len");
|
||||
gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word");
|
||||
gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens");
|
||||
gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up");
|
||||
gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank");
|
||||
gf->whisper_params.suppress_non_speech_tokens =
|
||||
obs_data_get_bool(s, "suppress_non_speech_tokens");
|
||||
gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature");
|
||||
gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts");
|
||||
gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty");
|
||||
}
|
||||
|
||||
// Creates a filter instance: allocates and constructs the filter data, sets up
// the audio buffers and resampler, loads the whisper model, installs the
// callback that writes text to the subtitle source, applies `settings` and
// starts the whisper thread.
//
// settings: initial filter settings.
// filter:   the OBS source this filter instance belongs to.
// Returns the new transcription_filter_data, or nullptr when the whisper model
// cannot be loaded.
void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
{
	// The struct holds C++ members (string, thread, smart pointers, callback),
	// so it has to be constructed in the bmalloc'ed storage; assigning into raw
	// memory is undefined behaviour. Placement new needs <new>, which the
	// standard headers this struct already relies on pull in.
	void *storage = bmalloc(sizeof(struct transcription_filter_data));
	struct transcription_filter_data *gf = new (storage) transcription_filter_data();

	// transcription_filter_update() inspects these before anything sets them
	gf->text_source = nullptr;
	gf->text_source_name = nullptr;
	gf->resampler = nullptr;

	// Get the number of channels for the input source
	gf->channels = audio_output_get_channels(obs_get_audio());
	gf->sample_rate = audio_output_get_sample_rate(obs_get_audio());
	gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)BUFFER_SIZE_MSEC));
	gf->last_num_frames = 0;

	for (size_t i = 0; i < MAX_AUDIO_CHANNELS; i++) {
		circlebuf_init(&gf->input_buffers[i]);
	}
	circlebuf_init(&gf->info_buffer);

	// allocate copy buffers: one allocation, sliced per channel
	gf->copy_buffers[0] =
		static_cast<float *>(bzalloc(gf->channels * gf->frames * sizeof(float)));
	for (size_t c = 1; c < gf->channels; c++) { // set the channel pointers
		gf->copy_buffers[c] = gf->copy_buffers[0] + c * gf->frames;
	}

	gf->context = filter;
	gf->whisper_model_path = obs_data_get_string(settings, "whisper_model_path");
	gf->whisper_context = init_whisper_context(gf->whisper_model_path);
	if (gf->whisper_context == nullptr) {
		obs_log(LOG_ERROR, "Failed to load whisper model");
		// undo everything allocated so far instead of leaking it
		bfree(gf->copy_buffers[0]);
		gf->~transcription_filter_data();
		bfree(gf);
		return nullptr;
	}

	gf->overlap_ms = OVERLAP_SIZE_MSEC;
	gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
	obs_log(LOG_INFO, "transcription_filter filter: channels %d, frames %d, sample_rate %d",
		(int)gf->channels, (int)gf->frames, gf->sample_rate);

	// resample from the OBS audio format to mono at the whisper sample rate
	struct resample_info src, dst;
	src.samples_per_sec = gf->sample_rate;
	src.format = AUDIO_FORMAT_FLOAT_PLANAR;
	src.speakers = convert_speaker_layout((uint8_t)gf->channels);

	dst.samples_per_sec = WHISPER_SAMPLE_RATE;
	dst.format = AUDIO_FORMAT_FLOAT_PLANAR;
	dst.speakers = convert_speaker_layout((uint8_t)1);

	gf->resampler = audio_resampler_create(&dst, &src);

	gf->active = true;

	gf->whisper_buf_mutex = std::unique_ptr<std::mutex>(new std::mutex());
	gf->whisper_ctx_mutex = std::unique_ptr<std::mutex>(new std::mutex());
	gf->wshiper_thread_cv =
		std::unique_ptr<std::condition_variable>(new std::condition_variable());
	gf->text_source_mutex = std::unique_ptr<std::mutex>(new std::mutex());

	// set the callback to set the text in the output text source (subtitles)
	gf->setTextCallback = [gf](const std::string &str) {
		if (!gf->text_source) {
			// attempt to acquire a weak ref to the text source if it's yet available
			acquire_weak_text_source_ref(gf);
		}

		std::lock_guard<std::mutex> lock(*gf->text_source_mutex);

		obs_weak_source_t *text_source = gf->text_source;
		if (!text_source) {
			obs_log(LOG_ERROR, "text_source is null");
			return;
		}
		auto target = obs_weak_source_get_source(text_source);
		if (!target) {
			obs_log(LOG_ERROR, "text_source target is null");
			return;
		}
		auto text_settings = obs_source_get_settings(target);
		obs_data_set_string(text_settings, "text", str.c_str());
		obs_source_update(target, text_settings);
		// obs_source_get_settings() hands out a counted reference; release it
		obs_data_release(text_settings);
		obs_source_release(target);
	};

	// get the settings updated on the filter data struct
	transcription_filter_update(gf, settings);

	// start the thread
	gf->whisper_thread = std::thread(whisper_loop, gf);

	return gf;
}
|
||||
|
||||
void transcription_filter_activate(void *data)
|
||||
{
|
||||
struct transcription_filter_data *gf =
|
||||
static_cast<struct transcription_filter_data *>(data);
|
||||
obs_log(LOG_INFO, "transcription_filter filter activated");
|
||||
gf->active = true;
|
||||
}
|
||||
|
||||
void transcription_filter_deactivate(void *data)
|
||||
{
|
||||
struct transcription_filter_data *gf =
|
||||
static_cast<struct transcription_filter_data *>(data);
|
||||
obs_log(LOG_INFO, "transcription_filter filter deactivated");
|
||||
gf->active = false;
|
||||
}
|
||||
|
||||
// Registers the default value of every setting read by
// transcription_filter_update(). The calls are independent of each other and
// are grouped by value type.
void transcription_filter_defaults(obs_data_t *s)
{
	// General filter options
	obs_data_set_default_int(s, "log_level", LOG_DEBUG);
	obs_data_set_default_bool(s, "vad_enabled", true);
	obs_data_set_default_bool(s, "log_words", true);
	obs_data_set_default_string(s, "subtitle_sources", "none");
	obs_data_set_default_string(s, "whisper_model_path", "models/ggml-tiny.en.bin");
	obs_data_set_default_string(s, "whisper_language_select", "en");

	// Whisper parameters: strings
	obs_data_set_default_string(s, "initial_prompt", "");

	// Whisper parameters: integers
	obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
	obs_data_set_default_int(s, "n_threads", 4);
	obs_data_set_default_int(s, "n_max_text_ctx", 16384);
	obs_data_set_default_int(s, "max_len", 0);
	obs_data_set_default_int(s, "max_tokens", 32);

	// Whisper parameters: flags
	obs_data_set_default_bool(s, "translate", false);
	obs_data_set_default_bool(s, "no_context", true);
	obs_data_set_default_bool(s, "single_segment", true);
	obs_data_set_default_bool(s, "print_special", false);
	obs_data_set_default_bool(s, "print_progress", false);
	obs_data_set_default_bool(s, "print_realtime", false);
	obs_data_set_default_bool(s, "print_timestamps", false);
	obs_data_set_default_bool(s, "token_timestamps", false);
	obs_data_set_default_bool(s, "split_on_word", false);
	obs_data_set_default_bool(s, "speed_up", false);
	obs_data_set_default_bool(s, "suppress_blank", false);
	obs_data_set_default_bool(s, "suppress_non_speech_tokens", true);

	// Whisper parameters: floating point
	obs_data_set_default_double(s, "thold_pt", 0.01);
	obs_data_set_default_double(s, "thold_ptsum", 0.01);
	obs_data_set_default_double(s, "temperature", 0.5);
	obs_data_set_default_double(s, "max_initial_ts", 1.0);
	obs_data_set_default_double(s, "length_penalty", -1.0);
}
|
||||
|
||||
// Builds the property sheet for the filter: general options, the subtitle
// source selector, the whisper model selector and a "Whisper Parameters" group.
// Properties are added in the same order as before, since that is the order in
// which they are registered with OBS. `data` is not used.
obs_properties_t *transcription_filter_properties(void *data)
{
	UNUSED_PARAMETER(data);

	obs_properties_t *props = obs_properties_create();

	obs_properties_add_bool(props, "vad_enabled", "VAD Enabled");

	obs_property_t *log_level_list = obs_properties_add_list(
		props, "log_level", "Log level", OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
	obs_property_list_add_int(log_level_list, "DEBUG", LOG_DEBUG);
	obs_property_list_add_int(log_level_list, "INFO", LOG_INFO);
	obs_property_list_add_int(log_level_list, "WARNING", LOG_WARNING);

	obs_properties_add_bool(props, "log_words", "Log output words");

	obs_property_t *subtitle_source_list =
		obs_properties_add_list(props, "subtitle_sources", "subtitle_sources",
					OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
	obs_enum_sources(add_sources_to_list, subtitle_source_list);

	// Add a list of available whisper models to download
	obs_property_t *model_list =
		obs_properties_add_list(props, "whisper_model_path", "Whisper Model",
					OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
	// {label, path} pairs, in display order
	static const char *const model_choices[][2] = {
		{"Tiny (Eng) 75Mb", "models/ggml-tiny.en.bin"},
		{"Tiny 75Mb", "models/ggml-tiny.bin"},
		{"Base (Eng) 142Mb", "models/ggml-base.en.bin"},
		{"Base 142Mb", "models/ggml-base.bin"},
		{"Small (Eng) 466Mb", "models/ggml-small.en.bin"},
		{"Small 466Mb", "models/ggml-small.bin"},
	};
	for (const auto &choice : model_choices) {
		obs_property_list_add_string(model_list, choice[0], choice[1]);
	}

	obs_properties_t *whisper_group = obs_properties_create();
	obs_properties_add_group(props, "whisper_params_group", "Whisper Parameters",
				 OBS_GROUP_NORMAL, whisper_group);

	// Language selector
	obs_property_t *language_list =
		obs_properties_add_list(whisper_group, "whisper_language_select", "Language",
					OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
	// Key the languages by name instead of code so iteration is sorted by name
	std::map<std::string, std::string> code_by_language;
	for (auto const &entry : whisper_available_lang) {
		code_by_language[entry.second] = entry.first;
	}
	for (auto const &entry : code_by_language) {
		// Capitalize the language name
		std::string label = entry.first;
		label[0] = (char)toupper(label[0]);

		obs_property_list_add_string(language_list, label.c_str(), entry.second.c_str());
	}

	obs_property_t *sampling_list = obs_properties_add_list(
		whisper_group, "whisper_sampling_method", "whisper_sampling_method",
		OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
	obs_property_list_add_int(sampling_list, "Beam search", WHISPER_SAMPLING_BEAM_SEARCH);
	obs_property_list_add_int(sampling_list, "Greedy", WHISPER_SAMPLING_GREEDY);

	// Helpers for whisper parameters whose label equals their setting key
	const auto add_flag = [whisper_group](const char *key) {
		obs_properties_add_bool(whisper_group, key, key);
	};
	const auto add_int_slider = [whisper_group](const char *key, int min_value,
						    int max_value, int step) {
		obs_properties_add_int_slider(whisper_group, key, key, min_value, max_value, step);
	};
	const auto add_float_slider = [whisper_group](const char *key, float min_value,
						      float max_value, float step) {
		obs_properties_add_float_slider(whisper_group, key, key, min_value, max_value,
						step);
	};

	add_int_slider("n_threads", 1, 8, 1);
	// max tokens to use from past text as prompt for the decoder
	add_int_slider("n_max_text_ctx", 0, 16384, 100);
	// offset_ms and duration_ms are not exposed as properties
	add_flag("translate");
	// do not use past transcription (if any) as initial prompt for the decoder
	add_flag("no_context");
	// force single segment output (useful for streaming)
	add_flag("single_segment");
	// print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
	add_flag("print_special");
	// print progress information
	add_flag("print_progress");
	// print results from within whisper.cpp (avoid it, use callback instead)
	add_flag("print_realtime");
	// print timestamps for each text segment when printing realtime
	add_flag("print_timestamps");
	// enable token-level timestamps
	add_flag("token_timestamps");
	// timestamp token probability threshold (~0.01)
	add_float_slider("thold_pt", 0.0f, 1.0f, 0.05f);
	// timestamp token sum probability threshold (~0.01)
	add_float_slider("thold_ptsum", 0.0f, 1.0f, 0.05f);
	// max segment length in characters
	add_int_slider("max_len", 0, 100, 1);
	// split on word rather than on token (when used with max_len)
	add_flag("split_on_word");
	// max tokens per segment (0 = no limit)
	add_int_slider("max_tokens", 0, 100, 1);
	// speed-up the audio by 2x using Phase Vocoder
	add_flag("speed_up");
	obs_properties_add_text(whisper_group, "initial_prompt", "initial_prompt",
				OBS_TEXT_DEFAULT);
	add_flag("suppress_blank");
	add_flag("suppress_non_speech_tokens");
	add_float_slider("temperature", 0.0f, 1.0f, 0.05f);
	add_float_slider("max_initial_ts", 0.0f, 1.0f, 0.05f);
	add_float_slider("length_penalty", -1.0f, 1.0f, 0.1f);

	return props;
}
|
||||
19
src/transcription-filter.h
Normal file
19
src/transcription-filter.h
Normal file
@@ -0,0 +1,19 @@
|
||||
// Entry points of the transcription filter, exposed with C linkage.
#ifndef TRANSCRIPTION_FILTER_H
#define TRANSCRIPTION_FILTER_H

#include <obs-module.h>

#ifdef __cplusplus
extern "C" {
#endif

// Marks the filter instance as active.
void transcription_filter_activate(void *data);
// Creates a filter instance; returns nullptr/NULL when the whisper model cannot be loaded.
void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter);
// Applies the settings in `s` to the filter instance.
void transcription_filter_update(void *data, obs_data_t *s);
// Stops the whisper thread and frees the filter instance.
void transcription_filter_destroy(void *data);
// Returns the filter's name string.
const char *transcription_filter_name(void *unused);
// Receives one audio packet for the filter instance and returns the packet to pass on.
struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_audio_data *audio);
// Marks the filter instance as inactive.
void transcription_filter_deactivate(void *data);
// Registers the default value of every filter setting.
void transcription_filter_defaults(obs_data_t *s);
// Builds the property sheet describing the filter settings.
obs_properties_t *transcription_filter_properties(void *data);

#ifdef __cplusplus
}
#endif

#endif // TRANSCRIPTION_FILTER_H
|
||||
410
src/whisper-language.h
Normal file
410
src/whisper-language.h
Normal file
@@ -0,0 +1,410 @@
|
||||
#ifndef WHISPER_LANGUAGE_H
|
||||
#define WHISPER_LANGUAGE_H
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
// Languages offered in the language selector: language code -> lowercase
// English language name ("auto" maps to itself).
static const std::map<std::string, std::string> whisper_available_lang = {
	{"auto", "auto"},
	{"en", "english"},
	{"zh", "chinese"},
	{"de", "german"},
	{"es", "spanish"},
	{"ru", "russian"},
	{"ko", "korean"},
	{"fr", "french"},
	{"ja", "japanese"},
	{"pt", "portuguese"},
	{"tr", "turkish"},
	{"pl", "polish"},
	{"ca", "catalan"},
	{"nl", "dutch"},
	{"ar", "arabic"},
	{"sv", "swedish"},
	{"it", "italian"},
	{"id", "indonesian"},
	{"hi", "hindi"},
	{"fi", "finnish"},
	{"vi", "vietnamese"},
	{"he", "hebrew"},
	{"uk", "ukrainian"},
	{"el", "greek"},
	{"ms", "malay"},
	{"cs", "czech"},
	{"ro", "romanian"},
	{"da", "danish"},
	{"hu", "hungarian"},
	{"ta", "tamil"},
	{"no", "norwegian"},
	{"th", "thai"},
	{"ur", "urdu"},
	{"hr", "croatian"},
	{"bg", "bulgarian"},
	{"lt", "lithuanian"},
	{"la", "latin"},
	{"mi", "maori"},
	{"ml", "malayalam"},
	{"cy", "welsh"},
	{"sk", "slovak"},
	{"te", "telugu"},
	{"fa", "persian"},
	{"lv", "latvian"},
	{"bn", "bengali"},
	{"sr", "serbian"},
	{"az", "azerbaijani"},
	{"sl", "slovenian"},
	{"kn", "kannada"},
	{"et", "estonian"},
	{"mk", "macedonian"},
	{"br", "breton"},
	{"eu", "basque"},
	{"is", "icelandic"},
	{"hy", "armenian"},
	{"ne", "nepali"},
	{"mn", "mongolian"},
	{"bs", "bosnian"},
	{"kk", "kazakh"},
	{"sq", "albanian"},
	{"sw", "swahili"},
	{"gl", "galician"},
	{"mr", "marathi"},
	{"pa", "punjabi"},
	{"si", "sinhala"},
	{"km", "khmer"},
	{"sn", "shona"},
	{"yo", "yoruba"},
	{"so", "somali"},
	{"af", "afrikaans"},
	{"oc", "occitan"},
	{"ka", "georgian"},
	{"be", "belarusian"},
	{"tg", "tajik"},
	{"sd", "sindhi"},
	{"gu", "gujarati"},
	{"am", "amharic"},
	{"yi", "yiddish"},
	{"lo", "lao"},
	{"uz", "uzbek"},
	{"fo", "faroese"},
	{"ht", "haitian"},
	{"ps", "pashto"},
	{"tk", "turkmen"},
	{"nn", "nynorsk"},
	{"mt", "maltese"},
	{"sa", "sanskrit"},
	{"lb", "luxembourgish"},
	{"my", "myanmar"},
	{"bo", "tibetan"},
	{"tl", "tagalog"},
	{"mg", "malagasy"},
	{"as", "assamese"},
	{"tt", "tatar"},
	{"haw", "hawaiian"},
	{"ln", "lingala"},
	{"ha", "hausa"},
	{"ba", "bashkir"},
	{"jw", "javanese"},
	{"su", "sundanese"},
};
|
||||
|
||||
#endif // WHISPER_LANGUAGE_H
|
||||
345
src/whisper-processing.cpp
Normal file
345
src/whisper-processing.cpp
Normal file
@@ -0,0 +1,345 @@
|
||||
#include <whisper.h>
|
||||
|
||||
#include <obs-module.h>
|
||||
|
||||
#include "plugin-support.h"
|
||||
#include "transcription-filter-data.h"
|
||||
#include "whisper-processing.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
|
||||
// Threshold constants. VAD_THOLD presumably feeds the vad_thold (energy
// threshold) parameter of vad_simple() and FREQ_THOLD its freq_thold
// (high-pass cutoff) parameter; the call site is not in this part of the
// file - TODO confirm.
#define VAD_THOLD 0.0001f
|
||||
#define FREQ_THOLD 100.0f
|
||||
|
||||
// Taken from https://github.com/ggerganov/whisper.cpp/blob/master/examples/stream/stream.cpp
|
||||
// Formats a whisper timestamp as "MM:SS.mmm".
//
// t: time in units of 10 ms (centiseconds), as returned by
//    whisper_full_get_segment_t0() / whisper_full_get_segment_t1().
// Returns the formatted string, e.g. 150 -> "00:01.500".
std::string to_timestamp(int64_t t)
{
	const int64_t total_sec = t / 100;
	// The remainder is in centiseconds; scale it so the three-digit field
	// really shows milliseconds (previously 1.50 s was printed as "00:01.050").
	const int64_t msec = (t % 100) * 10;
	const int64_t min = total_sec / 60;
	const int64_t sec = total_sec % 60;

	char buf[32];
	snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int)min, (int)sec, (int)msec);

	return std::string(buf);
}
|
||||
|
||||
// Applies a first-order high-pass filter to the samples in place.
//
// pcmf32:      samples to filter; the first sample is left unchanged.
// pcm32f_size: number of samples. Nothing is done for a null or empty buffer.
// cutoff:      cutoff frequency; must be non-zero (it is used as a divisor).
// sample_rate: sample rate of the buffer; must be non-zero.
void high_pass_filter(float *pcmf32, size_t pcm32f_size, float cutoff, uint32_t sample_rate)
{
	// pcmf32[0] is read unconditionally below, so bail out on an empty buffer
	if (pcmf32 == nullptr || pcm32f_size == 0) {
		return;
	}

	const float rc = 1.0f / (2.0f * (float)M_PI * cutoff);
	const float dt = 1.0f / (float)sample_rate;
	const float alpha = dt / (rc + dt);

	float y = pcmf32[0];

	for (size_t i = 1; i < pcm32f_size; i++) {
		// pcmf32[i - 1] has already been overwritten with the filtered value
		y = alpha * (y + pcmf32[i] - pcmf32[i - 1]);
		pcmf32[i] = y;
	}
}
|
||||
|
||||
// VAD (voice activity detection), return true if speech detected
|
||||
bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float vad_thold,
|
||||
float freq_thold, bool verbose)
|
||||
{
|
||||
const uint64_t n_samples = pcm32f_size;
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, pcm32f_size, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
|
||||
for (uint64_t i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
}
|
||||
|
||||
energy_all /= (float)n_samples;
|
||||
|
||||
if (verbose) {
|
||||
blog(LOG_INFO, "%s: energy_all: %f, vad_thold: %f, freq_thold: %f", __func__,
|
||||
energy_all, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_all < vad_thold) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
struct whisper_context *init_whisper_context(const std::string &model_path)
|
||||
{
|
||||
struct whisper_context *ctx = whisper_init_from_file(obs_module_file(model_path.c_str()));
|
||||
if (ctx == nullptr) {
|
||||
obs_log(LOG_ERROR, "Failed to load whisper model");
|
||||
return nullptr;
|
||||
}
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// Outcome of one inference pass over an audio segment.
enum DetectionResult {
	DETECTION_RESULT_UNKNOWN = 0, // inference did not run or failed
	DETECTION_RESULT_SILENCE,     // = 1, inference produced no text
	DETECTION_RESULT_SPEECH,      // = 2, inference produced text
};

// An inference outcome together with the recognised text, which is empty
// unless the result is DETECTION_RESULT_SPEECH.
struct DetectionResultWithText {
	DetectionResult result;
	std::string text;
};
|
||||
|
||||
struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
|
||||
const float *pcm32f_data, size_t pcm32f_size)
|
||||
{
|
||||
obs_log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__,
|
||||
int(pcm32f_size), float(pcm32f_size) / WHISPER_SAMPLE_RATE,
|
||||
gf->whisper_params.n_threads);
|
||||
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
if (gf->whisper_context == nullptr) {
|
||||
obs_log(LOG_WARNING, "whisper context is null");
|
||||
return {DETECTION_RESULT_UNKNOWN, ""};
|
||||
}
|
||||
|
||||
// run the inference
|
||||
int whisper_full_result = -1;
|
||||
try {
|
||||
whisper_full_result = whisper_full(gf->whisper_context, gf->whisper_params,
|
||||
pcm32f_data, (int)pcm32f_size);
|
||||
} catch (const std::exception &e) {
|
||||
obs_log(LOG_ERROR, "Whisper exception: %s. Filter restart is required", e.what());
|
||||
whisper_free(gf->whisper_context);
|
||||
gf->whisper_context = nullptr;
|
||||
return {DETECTION_RESULT_UNKNOWN, ""};
|
||||
}
|
||||
|
||||
if (whisper_full_result != 0) {
|
||||
obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result);
|
||||
return {DETECTION_RESULT_UNKNOWN, ""};
|
||||
} else {
|
||||
const int n_segment = 0;
|
||||
const char *text = whisper_full_get_segment_text(gf->whisper_context, n_segment);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(gf->whisper_context, n_segment);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(gf->whisper_context, n_segment);
|
||||
|
||||
float sentence_p = 0.0f;
|
||||
const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
sentence_p += whisper_full_get_token_p(gf->whisper_context, n_segment, j);
|
||||
}
|
||||
sentence_p /= (float)n_tokens;
|
||||
|
||||
// convert text to lowercase
|
||||
std::string text_lower(text);
|
||||
std::transform(text_lower.begin(), text_lower.end(), text_lower.begin(), ::tolower);
|
||||
// trim whitespace (use lambda)
|
||||
text_lower.erase(std::find_if(text_lower.rbegin(), text_lower.rend(),
|
||||
[](unsigned char ch) { return !std::isspace(ch); })
|
||||
.base(),
|
||||
text_lower.end());
|
||||
|
||||
if (gf->log_words) {
|
||||
obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
|
||||
to_timestamp(t1).c_str(), sentence_p, text_lower.c_str());
|
||||
}
|
||||
|
||||
if (text_lower.empty()) {
|
||||
return {DETECTION_RESULT_SILENCE, ""};
|
||||
}
|
||||
|
||||
return {DETECTION_RESULT_SPEECH, text_lower};
|
||||
}
|
||||
}
|
||||
|
||||
void process_audio_from_buffer(struct transcription_filter_data *gf)
|
||||
{
|
||||
uint32_t num_new_frames_from_infos = 0;
|
||||
uint64_t start_timestamp = 0;
|
||||
|
||||
{
|
||||
// scoped lock the buffer mutex
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex);
|
||||
|
||||
// We need (gf->frames - gf->overlap_frames) new frames to run inference,
|
||||
// except for the first segment, where we need the whole gf->frames frames
|
||||
size_t how_many_frames_needed = gf->frames - gf->overlap_frames;
|
||||
if (gf->last_num_frames == 0) {
|
||||
how_many_frames_needed = gf->frames;
|
||||
}
|
||||
|
||||
// pop infos from the info buffer and mark the beginning timestamp from the first
|
||||
// info as the beginning timestamp of the segment
|
||||
struct transcription_filter_audio_info info_from_buf = {0};
|
||||
while (gf->info_buffer.size >= sizeof(struct transcription_filter_audio_info)) {
|
||||
circlebuf_pop_front(&gf->info_buffer, &info_from_buf,
|
||||
sizeof(struct transcription_filter_audio_info));
|
||||
num_new_frames_from_infos += info_from_buf.frames;
|
||||
if (start_timestamp == 0) {
|
||||
start_timestamp = info_from_buf.timestamp;
|
||||
}
|
||||
obs_log(gf->log_level, "popped %d frames from info buffer, %lu needed",
|
||||
num_new_frames_from_infos, how_many_frames_needed);
|
||||
// Check if we're within the needed segment length
|
||||
if (num_new_frames_from_infos > how_many_frames_needed) {
|
||||
// too big, push the last info into the buffer's front where it was
|
||||
num_new_frames_from_infos -= info_from_buf.frames;
|
||||
circlebuf_push_front(
|
||||
&gf->info_buffer, &info_from_buf,
|
||||
sizeof(struct transcription_filter_audio_info));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/* Pop from input circlebuf */
|
||||
for (size_t c = 0; c < gf->channels; c++) {
|
||||
if (gf->last_num_frames > 0) {
|
||||
// move overlap frames from the end of the last copy_buffers to the beginning
|
||||
memcpy(gf->copy_buffers[c],
|
||||
gf->copy_buffers[c] + gf->last_num_frames -
|
||||
gf->overlap_frames,
|
||||
gf->overlap_frames * sizeof(float));
|
||||
// copy new data to the end of copy_buffers[c]
|
||||
circlebuf_pop_front(&gf->input_buffers[c],
|
||||
gf->copy_buffers[c] + gf->overlap_frames,
|
||||
num_new_frames_from_infos * sizeof(float));
|
||||
} else {
|
||||
// Very first time, just copy data to copy_buffers[c]
|
||||
circlebuf_pop_front(&gf->input_buffers[c], gf->copy_buffers[c],
|
||||
num_new_frames_from_infos * sizeof(float));
|
||||
}
|
||||
}
|
||||
obs_log(gf->log_level,
|
||||
"popped %u frames from input buffer. input_buffer[0] size is %lu",
|
||||
num_new_frames_from_infos, gf->input_buffers[0].size);
|
||||
|
||||
if (gf->last_num_frames > 0) {
|
||||
gf->last_num_frames = num_new_frames_from_infos + gf->overlap_frames;
|
||||
} else {
|
||||
gf->last_num_frames = num_new_frames_from_infos;
|
||||
}
|
||||
}
|
||||
|
||||
obs_log(gf->log_level, "processing %d frames (%d ms), start timestamp %llu ",
|
||||
(int)gf->last_num_frames, (int)(gf->last_num_frames * 1000 / gf->sample_rate),
|
||||
start_timestamp);
|
||||
|
||||
// time the audio processing
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// resample to 16kHz
|
||||
float *output[MAX_PREPROC_CHANNELS];
|
||||
uint32_t out_frames;
|
||||
uint64_t ts_offset;
|
||||
audio_resampler_resample(gf->resampler, (uint8_t **)output, &out_frames, &ts_offset,
|
||||
(const uint8_t **)gf->copy_buffers, (uint32_t)gf->last_num_frames);
|
||||
|
||||
obs_log(gf->log_level, "%d channels, %d frames, %f ms", (int)gf->channels, (int)out_frames,
|
||||
(float)out_frames / WHISPER_SAMPLE_RATE * 1000.0f);
|
||||
|
||||
bool skipped_inference = false;
|
||||
|
||||
if (gf->vad_enabled) {
|
||||
skipped_inference = !::vad_simple(output[0], out_frames, WHISPER_SAMPLE_RATE,
|
||||
VAD_THOLD, FREQ_THOLD,
|
||||
gf->log_level != LOG_DEBUG);
|
||||
}
|
||||
|
||||
if (!skipped_inference) {
|
||||
// run inference
|
||||
const struct DetectionResultWithText inference_result =
|
||||
run_whisper_inference(gf, output[0], out_frames);
|
||||
|
||||
if (inference_result.result == DETECTION_RESULT_SPEECH) {
|
||||
// output inference result to a text source
|
||||
gf->setTextCallback(inference_result.text);
|
||||
} else if (inference_result.result == DETECTION_RESULT_SILENCE) {
|
||||
// output inference result to a text source
|
||||
gf->setTextCallback("[silence]");
|
||||
}
|
||||
} else {
|
||||
if (gf->log_words) {
|
||||
obs_log(LOG_INFO, "skipping inference");
|
||||
}
|
||||
gf->setTextCallback("");
|
||||
}
|
||||
|
||||
// end of timer
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
|
||||
const uint32_t new_frames_from_infos_ms =
|
||||
num_new_frames_from_infos * 1000 /
|
||||
gf->sample_rate; // number of frames in this packet
|
||||
obs_log(gf->log_level, "audio processing of %u ms new data took %d ms",
|
||||
new_frames_from_infos_ms, (int)duration);
|
||||
|
||||
if (duration > new_frames_from_infos_ms) {
|
||||
// try to decrease overlap down to minimum of 100 ms
|
||||
gf->overlap_ms = std::max((uint64_t)gf->overlap_ms - 10, (uint64_t)100);
|
||||
gf->overlap_frames = gf->overlap_ms * gf->sample_rate / 1000;
|
||||
obs_log(gf->log_level,
|
||||
"audio processing took too long (%d ms), reducing overlap to %lu ms",
|
||||
(int)duration, gf->overlap_ms);
|
||||
} else if (!skipped_inference) {
|
||||
if (gf->overlap_ms < OVERLAP_SIZE_MSEC) {
|
||||
// try to increase overlap up to OVERLAP_SIZE_MSEC
|
||||
gf->overlap_ms = std::min((uint64_t)gf->overlap_ms + 10,
|
||||
(uint64_t)OVERLAP_SIZE_MSEC);
|
||||
gf->overlap_frames = gf->overlap_ms * gf->sample_rate / 1000;
|
||||
obs_log(gf->log_level,
|
||||
"audio processing took %d ms, increasing overlap to %lu ms",
|
||||
(int)duration, gf->overlap_ms);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void whisper_loop(void *data)
|
||||
{
|
||||
struct transcription_filter_data *gf =
|
||||
static_cast<struct transcription_filter_data *>(data);
|
||||
const size_t segment_size = gf->frames * sizeof(float);
|
||||
|
||||
obs_log(LOG_INFO, "starting whisper thread");
|
||||
|
||||
// Thread main loop
|
||||
while (true) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
if (gf->whisper_context == nullptr) {
|
||||
obs_log(LOG_WARNING, "Whisper context is null, exiting thread");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we have enough data to process
|
||||
while (true) {
|
||||
size_t input_buf_size = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex);
|
||||
input_buf_size = gf->input_buffers[0].size;
|
||||
}
|
||||
|
||||
if (input_buf_size >= segment_size) {
|
||||
obs_log(gf->log_level,
|
||||
"found %lu bytes, %lu frames in input buffer, need >= %lu, processing",
|
||||
input_buf_size, (size_t)(input_buf_size / sizeof(float)),
|
||||
segment_size);
|
||||
|
||||
// Process the audio. This will also remove the processed data from the input buffer.
|
||||
// Mutex is locked inside process_audio_from_buffer.
|
||||
process_audio_from_buffer(gf);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Sleep for 10 ms using the condition variable wshiper_thread_cv
|
||||
// This will wake up the thread if there is new data in the input buffer
|
||||
// or if the whisper context is null
|
||||
std::unique_lock<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
gf->wshiper_thread_cv->wait_for(lock, std::chrono::milliseconds(10));
|
||||
}
|
||||
|
||||
obs_log(LOG_INFO, "exiting whisper thread");
|
||||
}
|
||||
14
src/whisper-processing.h
Normal file
14
src/whisper-processing.h
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef WHISPER_PROCESSING_H
|
||||
#define WHISPER_PROCESSING_H
|
||||
|
||||
// buffer size in msec
|
||||
#define BUFFER_SIZE_MSEC 3000
|
||||
// at 16Khz, 3000 msec is 48000 samples
|
||||
#define WHISPER_FRAME_SIZE 48000
|
||||
// overlap in msec
|
||||
#define OVERLAP_SIZE_MSEC 200
|
||||
|
||||
void whisper_loop(void *data);
|
||||
struct whisper_context *init_whisper_context(const std::string &model_path);
|
||||
|
||||
#endif // WHISPER_PROCESSING_H
|
||||
1
vendor/curl
vendored
Submodule
1
vendor/curl
vendored
Submodule
Submodule vendor/curl added at 439ff2052e
Reference in New Issue
Block a user