refactor: Update vad_threshold default value to 0.95

This commit is contained in:
Roy Shilkrot
2024-07-18 02:15:02 -04:00
parent 71beb03231
commit b009bd0c37
3 changed files with 93 additions and 86 deletions

View File

@@ -190,6 +190,7 @@ void set_text_callback(struct transcription_filter_data *gf,
{
DetectionResultWithText result = resultIn;
if (!result.text.empty() && result.result == DETECTION_RESULT_SPEECH) {
// this sub should be rendered - update the last sub render time
gf->last_sub_render_time = now_ms();
gf->cleared_last_sub = false;
}
@@ -222,19 +223,22 @@ void set_text_callback(struct transcription_filter_data *gf,
}
}
// time the translation
uint64_t start_time = now_ms();
// send the sentence to translation (if enabled)
std::string translated_sentence =
send_sentence_to_translation(str_copy, gf, result.language);
// Timed metadata request
if (!gf->translate) {
send_timed_metadata_to_server(gf, TRANSCRIBE, str_copy, result.language, "", "");
} else {
if (gf->translate) {
// log the translation time
obs_log(gf->log_level, "Translation time: %llu ms", now_ms() - start_time);
// send the translated sentence to the server
send_timed_metadata_to_server(gf, NON_WHISPER_TRANSLATE, str_copy, result.language,
translated_sentence, gf->target_lang);
}
if (gf->translate) {
// send the translated sentence to the selected output
if (gf->translation_output == "none") {
// overwrite the original text with the translated text
str_copy = translated_sentence;
@@ -247,6 +251,8 @@ void set_text_callback(struct transcription_filter_data *gf,
gf);
}
}
} else {
send_timed_metadata_to_server(gf, TRANSCRIBE, str_copy, result.language, "", "");
}
if (gf->buffered_output) {

View File

@@ -571,7 +571,7 @@ void transcription_filter_defaults(obs_data_t *s)
(int)TokenBufferSegmentation::SEGMENTATION_WORD);
obs_data_set_default_bool(s, "vad_enabled", true);
obs_data_set_default_double(s, "vad_threshold", 0.65);
obs_data_set_default_double(s, "vad_threshold", 0.95);
obs_data_set_default_int(s, "log_level", LOG_DEBUG);
obs_data_set_default_bool(s, "log_words", false);
obs_data_set_default_bool(s, "caption_to_stream", false);

View File

@@ -172,6 +172,9 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""};
}
// time the operation
auto start = std::chrono::high_resolution_clock::now();
// run the inference
int whisper_full_result = -1;
gf->whisper_params.duration_ms = (int)(duration_ms);
@@ -191,6 +194,11 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
bfree(pcm32f_data);
}
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
obs_log(gf->log_level, "Transcription time: %d ms for %d ms of audio",
(int)duration.count(), (int)duration_ms);
std::string language = gf->whisper_params.language;
if (gf->whisper_params.language == nullptr || strlen(gf->whisper_params.language) == 0 ||
strcmp(gf->whisper_params.language, "auto") == 0) {
@@ -202,86 +210,79 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
if (whisper_full_result != 0) {
obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result);
return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""};
} else {
float sentence_p = 0.0f;
std::string text = "";
std::string tokenIds = "";
std::vector<whisper_token_data> tokens;
for (int n_segment = 0; n_segment < whisper_full_n_segments(gf->whisper_context);
++n_segment) {
const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment);
for (int j = 0; j < n_tokens; ++j) {
// get token
whisper_token_data token = whisper_full_get_token_data(
gf->whisper_context, n_segment, j);
const char *token_str =
whisper_token_to_str(gf->whisper_context, token.id);
bool keep = true;
// if the token starts with '[' and ends with ']', don't keep it
if (token_str[0] == '[' &&
token_str[strlen(token_str) - 1] == ']') {
keep = false;
}
// if this is a special token, don't keep it
if (token.id >= 50256) {
keep = false;
}
// if the second to last token is .id == 13 ('.'), don't keep it
if (j == n_tokens - 2 && token.id == 13) {
keep = false;
}
// token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json
if (token.id > 50365 && token.id <= 51865) {
const float time = ((float)token.id - 50365.0f) * 0.02f;
const float duration_s = (float)duration_ms / 1000.0f;
const float ratio = std::max(time, duration_s) /
std::min(time, duration_s);
obs_log(gf->log_level,
"Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f.",
token.id, time, duration_s, ratio);
if (ratio > 3.0f) {
// ratio is too high, skip this detection
obs_log(gf->log_level,
"Time token ratio too high, skipping");
return {DETECTION_RESULT_SILENCE,
"",
t0,
t1,
{},
language};
}
keep = false;
}
if (keep) {
sentence_p += token.p;
text += token_str;
tokens.push_back(token);
}
obs_log(gf->log_level, "S %d, Token %d: %d\t%s\tp: %.3f [keep: %d]",
n_segment, j, token.id, token_str, token.p, keep);
}
}
sentence_p /= (float)tokens.size();
if (sentence_p < gf->sentence_psum_accept_thresh) {
obs_log(gf->log_level, "Sentence psum %.3f below threshold %.3f, skipping",
sentence_p, gf->sentence_psum_accept_thresh);
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language};
}
obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str());
if (gf->log_words) {
obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
to_timestamp(t1).c_str(), sentence_p, text.c_str());
}
if (text.empty() || text == "." || text == " " || text == "\n") {
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language};
}
return {DETECTION_RESULT_SPEECH, text, t0, t1, tokens, language};
}
float sentence_p = 0.0f;
std::string text = "";
std::string tokenIds = "";
std::vector<whisper_token_data> tokens;
for (int n_segment = 0; n_segment < whisper_full_n_segments(gf->whisper_context);
++n_segment) {
const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment);
for (int j = 0; j < n_tokens; ++j) {
// get token
whisper_token_data token =
whisper_full_get_token_data(gf->whisper_context, n_segment, j);
const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
bool keep = true;
// if the token starts with '[' and ends with ']', don't keep it
if (token_str[0] == '[' && token_str[strlen(token_str) - 1] == ']') {
keep = false;
}
// if this is a special token, don't keep it
if (token.id >= 50256) {
keep = false;
}
// if the second to last token is .id == 13 ('.'), don't keep it
if (j == n_tokens - 2 && token.id == 13) {
keep = false;
}
// token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json
if (token.id > 50365 && token.id <= 51865) {
const float time = ((float)token.id - 50365.0f) * 0.02f;
const float duration_s = (float)duration_ms / 1000.0f;
const float ratio =
std::max(time, duration_s) / std::min(time, duration_s);
obs_log(gf->log_level,
"Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f.",
token.id, time, duration_s, ratio);
if (ratio > 3.0f) {
// ratio is too high, skip this detection
obs_log(gf->log_level,
"Time token ratio too high, skipping");
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language};
}
keep = false;
}
if (keep) {
sentence_p += token.p;
text += token_str;
tokens.push_back(token);
}
obs_log(gf->log_level, "S %d, Token %d: %d\t%s\tp: %.3f [keep: %d]",
n_segment, j, token.id, token_str, token.p, keep);
}
}
sentence_p /= (float)tokens.size();
if (sentence_p < gf->sentence_psum_accept_thresh) {
obs_log(gf->log_level, "Sentence psum %.3f below threshold %.3f, skipping",
sentence_p, gf->sentence_psum_accept_thresh);
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language};
}
obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str());
if (gf->log_words) {
obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
to_timestamp(t1).c_str(), sentence_p, text.c_str());
}
if (text.empty() || text == "." || text == " " || text == "\n") {
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language};
}
return {DETECTION_RESULT_SPEECH, text, t0, t1, tokens, language};
}
void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_offset_ms,