Use Silero VAD in Batched Mode (#936)

Replace Pyannote VAD with Silero to reduce code duplication and requirements
This commit is contained in:
Mahmoud Ashraf
2024-10-24 12:05:25 +03:00
committed by GitHub
parent 574e2563e7
commit 2dbca5e559
12 changed files with 278 additions and 509 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -7,6 +7,7 @@ import zlib
from collections import Counter, defaultdict
from inspect import signature
from math import ceil
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
import ctranslate2
@@ -14,26 +15,18 @@ import numpy as np
import tokenizers
import torch
from pyannote.audio import Model
from tqdm import tqdm
from faster_whisper.audio import decode_audio, pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import (
download_model,
format_timestamp,
get_assets_path,
get_end,
get_logger,
)
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
from faster_whisper.vad import (
SpeechTimestampsMap,
VadOptions,
VoiceActivitySegmentation,
collect_chunks,
get_speech_timestamps,
merge_chunks,
merge_segments,
)
@@ -115,67 +108,26 @@ class BatchedInferencePipeline:
def __init__(
self,
model,
use_vad_model: bool = True,
options: Optional[NamedTuple] = None,
tokenizer=None,
chunk_length: int = 30,
vad_device: Union[int, str, "torch.device"] = "auto",
vad_onset: float = 0.500,
vad_offset: float = 0.363,
language: Optional[str] = None,
):
self.model: WhisperModel = model
self.tokenizer = tokenizer
self.options = options
self.preset_language = language
self.use_vad_model = use_vad_model
self.vad_onset = vad_onset
self.vad_offset = vad_offset
self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
if self.use_vad_model:
self.vad_device = self.get_device(vad_device)
self.vad_model = self.load_vad_model(
vad_onset=self.vad_onset, vad_offset=self.vad_offset
)
else:
self.vad_model = None
self.chunk_length = chunk_length # VAD merging size
self.last_speech_timestamp = 0.0
def get_device(self, device: Union[int, str, "torch.device"]):
"""
Converts the input device into a torch.device object.
The input can be an integer, a string, or a `torch.device` object.
The function handles a special case where the input device is "auto".
When "auto" is specified, the device will default to the
device of the model (self.model.device). If the model's device is also "auto",
it selects "cuda" if a CUDA-capable device is available; otherwise, it selects "cpu".
"""
if isinstance(device, torch.device):
return device
elif isinstance(device, str):
if device == "auto" and self.model.device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
elif device == "auto":
device = self.model.device
return torch.device(device)
elif device < 0:
return torch.device("cpu")
else:
return torch.device(f"cuda:{device}")
def forward(self, features, segments_metadata, **forward_params):
def forward(self, features, chunks_metadata, **forward_params):
encoder_output, outputs = self.model.generate_segment_batched(
features, self.tokenizer, forward_params
)
segmented_outputs = []
segment_sizes = []
for segment_metadata, output in zip(segments_metadata, outputs):
duration = segment_metadata["end_time"] - segment_metadata["start_time"]
segment_size = int(duration * self.model.frames_per_second)
for chunk_metadata, output in zip(chunks_metadata, outputs):
duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
segment_size = int(ceil(duration) * self.model.frames_per_second)
segment_sizes.append(segment_size)
(
subsegments,
@@ -184,7 +136,7 @@ class BatchedInferencePipeline:
) = self.model._split_segments_by_timestamps(
tokenizer=self.tokenizer,
tokens=output["tokens"],
time_offset=segment_metadata["start_time"],
time_offset=chunk_metadata["start_time"],
segment_size=segment_size,
segment_duration=duration,
seek=0,
@@ -252,43 +204,9 @@ class BatchedInferencePipeline:
return language, language_probability, task, all_language_probs
@staticmethod
def audio_split(audio, segments, sampling_rate):
"""Returns splitted audio chunks as iterator"""
audio_segments = []
segments_metadata = []
for seg in segments:
f1 = int(seg["start"] * sampling_rate)
f2 = int(seg["end"] * sampling_rate)
seg_metadata = {
"start_time": seg["start"],
"end_time": seg["end"],
"stitched_seg": seg["segments"],
}
audio_segments.append(audio[f1:f2])
segments_metadata.append(seg_metadata)
return audio_segments, segments_metadata
def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
vad_model = Model.from_pretrained(self.vad_model_path)
hyperparameters = {
"onset": vad_onset,
"offset": vad_offset,
"min_duration_on": 0.1,
"min_duration_off": 0.1,
}
vad_pipeline = VoiceActivitySegmentation(
segmentation=vad_model, device=torch.device(self.vad_device)
)
vad_pipeline.instantiate(hyperparameters)
return vad_pipeline
def transcribe(
self,
audio: Union[str, torch.Tensor, np.ndarray],
vad_segments: Optional[List[dict]] = None,
batch_size: int = 16,
audio: Union[str, BinaryIO, torch.Tensor, np.ndarray],
language: Optional[str] = None,
task: str = None,
log_progress: bool = False,
@@ -314,26 +232,26 @@ class BatchedInferencePipeline:
prefix: Optional[str] = None,
suppress_blank: bool = True,
suppress_tokens: Optional[List[int]] = [-1],
without_timestamps: bool = True,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
vad_filter: bool = True,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
chunk_length: Optional[int] = None,
clip_timestamps: Optional[List[dict]] = None,
batch_size: int = 16,
hotwords: Optional[str] = None,
word_timestamps: bool = False,
without_timestamps: bool = True,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""transcribe audio in chunks in batched fashion and return with language info.
Arguments:
audio: audio file as numpy array/path for batched transcription.
vad_segments: Optionally provide list of dictionaries each containing "start", "end",
and "segments" keys.
"start" and "end" keys specify the start and end of the voiced region within
30 sec boundary. An additional key "segments" contains all the start
and end of voiced regions within that 30sec boundary as a list of tuples.
If no vad_segments specified, it uses internal vad model automatically segment them.
batch_size: the maximum number of parallel requests to model for decoding.
language: The language spoken in the audio.
task: either "transcribe" or "translate".
audio: Path to the input file (or a file-like object), or the audio waveform.
language: The language spoken in the audio. It should be a language code such
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
of audio.
task: Task to execute (transcribe or translate).
log_progress: whether to show progress bar or not.
beam_size: Beam size to use for decoding.
best_of: Number of candidates when sampling with non-zero temperature.
@@ -350,8 +268,8 @@ class BatchedInferencePipeline:
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
@@ -361,18 +279,29 @@ class BatchedInferencePipeline:
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in `tokenizer.non_speech_tokens()`.
without_timestamps: Only sample text tokens.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
Set as False.
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
the maximum will be set by the default max_length.
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
clip_timestamps: Optionally provide list of dictionaries each containing "start" and
"end" keys that specify the start and end of the voiced region within
`chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used.
batch_size: the maximum number of parallel requests to model for decoding.
hotwords:
Hotwords/hint phrases to the model. Has no effect if prefix is not None.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
Set as False.
without_timestamps: Only sample text tokens.
Static params: (Fixed for batched version)
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
@@ -390,28 +319,18 @@ class BatchedInferencePipeline:
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected. set as None.
clip_timestamps:
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
process. The last end timestamp defaults to the end of the file. Set as "0".
unused:
language_detection_threshold: If the maximum probability of the language tokens is
higher than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
Returns:
A tuple with:
- a generator over transcribed batched segments.
- an instance of TranscriptionInfo.
- a generator over transcribed segments
- an instance of TranscriptionInfo
"""
sampling_rate = self.model.feature_extractor.sampling_rate
@@ -422,29 +341,32 @@ class BatchedInferencePipeline:
audio = decode_audio(audio, sampling_rate=sampling_rate)
duration = audio.shape[0] / sampling_rate
chunk_length = chunk_length or self.model.feature_extractor.chunk_length
# if no segment split is provided, use vad_model and generate segments
if not vad_segments:
# run the audio if it is less than 30 sec even without vad_segments
if self.use_vad_model:
vad_segments = self.vad_model(
{
"waveform": audio.unsqueeze(0),
"sample_rate": 16000,
}
)
vad_segments = merge_chunks(
vad_segments,
self.chunk_length,
onset=self.vad_onset,
offset=self.vad_offset,
)
elif duration < self.chunk_length:
vad_segments = [
{"start": 0.0, "end": duration, "segments": [(0.0, duration)]}
]
if not clip_timestamps:
if vad_filter:
if vad_parameters is None:
vad_parameters = VadOptions(
max_speech_duration_s=chunk_length,
min_silence_duration_ms=160,
)
elif isinstance(vad_parameters, dict):
if "max_speech_duration_s" in vad_parameters.keys():
vad_parameters.pop("max_speech_duration_s")
vad_parameters = VadOptions(
**vad_parameters, max_speech_duration_s=chunk_length
)
active_segments = get_speech_timestamps(audio, vad_parameters)
clip_timestamps = merge_segments(active_segments, vad_parameters)
# run the audio if it is less than 30 sec even without clip_timestamps
elif duration < chunk_length:
clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
else:
raise RuntimeError(
"No vad segments found. Set 'use_vad_model' to True while loading the model"
"No clip timestamps found. "
"Set 'vad_filter' to True or provide 'clip_timestamps'."
)
if self.model.model.is_multilingual:
language = language or self.preset_language
@@ -452,7 +374,7 @@ class BatchedInferencePipeline:
if language is not None:
self.model.logger.warning(
f"English-only model is used, but {language} language is"
"chosen, setting language to 'en'."
" chosen, setting language to 'en'."
)
language = "en"
@@ -463,8 +385,9 @@ class BatchedInferencePipeline:
all_language_probs,
) = self.get_language_and_tokenizer(audio, task, language)
duration_after_vad = sum(
segment["end"] - segment["start"] for segment in vad_segments
duration_after_vad = (
sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
/ sampling_rate
)
# batched options: see the difference with default options in WhisperModel
@@ -511,27 +434,26 @@ class BatchedInferencePipeline:
all_language_probs=all_language_probs,
)
audio_segments, segments_metadata = self.audio_split(
audio, vad_segments, sampling_rate
)
audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
to_cpu = (
self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
)
audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
padding=0
)
features = torch.stack(
[
self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
..., : self.model.feature_extractor.nb_max_frames
features = (
torch.stack(
[
self.model.feature_extractor(chunk, to_cpu=to_cpu)[
..., : self.model.feature_extractor.nb_max_frames
]
for chunk in audio_chunks
]
for audio_segment in audio_segments
]
)
if duration_after_vad
else []
)
segments = self._batched_segments_generator(
features,
segments_metadata,
chunks_metadata,
batch_size,
batched_options,
log_progress,
@@ -540,14 +462,14 @@ class BatchedInferencePipeline:
return segments, info
def _batched_segments_generator(
self, features, segments_metadata, batch_size, options, log_progress
self, features, chunks_metadata, batch_size, options, log_progress
):
pbar = tqdm(total=len(features), disable=not log_progress, position=0)
seg_idx = 0
for i in range(0, len(features), batch_size):
results = self.forward(
features[i : i + batch_size],
segments_metadata[i : i + batch_size],
chunks_metadata[i : i + batch_size],
**options._asdict(),
)
@@ -850,7 +772,8 @@ class WhisperModel:
elif isinstance(vad_parameters, dict):
vad_parameters = VadOptions(**vad_parameters)
speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio = collect_chunks(audio, speech_chunks)
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
audio = torch.cat(audio_chunks, dim=0)
duration_after_vad = audio.shape[0] / sampling_rate
self.logger.info(
@@ -1905,7 +1828,8 @@ class WhisperModel:
# get chunks of audio that contain speech
speech_chunks = get_speech_timestamps(audio, vad_params)
# merge chunks of audio that contain speech into a single array
audio = collect_chunks(audio, speech_chunks)
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
audio = torch.cat(audio_chunks, dim=0)
# calculate new duration of audio without silence
duration_vad = audio.shape[0] / sampling_rate

View File

@@ -2,18 +2,11 @@ import bisect
import functools
import os
from abc import ABC
from collections.abc import Callable
from typing import List, NamedTuple, Optional, Union
from typing import Dict, List, NamedTuple, Optional, Tuple
import numpy as np
import torch
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from faster_whisper.utils import get_assets_path
@@ -22,9 +15,14 @@ class VadOptions(NamedTuple):
"""VAD options.
Attributes:
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
onset: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
offset: Silence threshold for determining the end of speech. If a probability is lower than
the offset, it is always considered silence. Values higher than offset are only considered
speech if the previous sample was classified as speech; otherwise, they are treated as
silence. This parameter helps refine the detection of speech transitions, ensuring smoother
segment boundaries.
min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
than max_speech_duration_s will be split at the timestamp of the last silence that
@@ -35,8 +33,9 @@ class VadOptions(NamedTuple):
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
"""
threshold: float = 0.5
min_speech_duration_ms: int = 250
onset: float = 0.5
offset: float = onset - 0.15
min_speech_duration_ms: int = 0
max_speech_duration_s: float = float("inf")
min_silence_duration_ms: int = 2000
speech_pad_ms: int = 400
@@ -45,6 +44,7 @@ class VadOptions(NamedTuple):
def get_speech_timestamps(
audio: torch.Tensor,
vad_options: Optional[VadOptions] = None,
sampling_rate: int = 16000,
**kwargs,
) -> List[dict]:
"""This method is used for splitting long audios into speech chunks using silero VAD.
@@ -52,6 +52,7 @@ def get_speech_timestamps(
Args:
audio: One dimensional float array.
vad_options: Options for VAD processing.
sampling rate: Sampling rate of the audio.
kwargs: VAD options passed as keyword arguments for backward compatibility.
Returns:
@@ -60,13 +61,12 @@ def get_speech_timestamps(
if vad_options is None:
vad_options = VadOptions(**kwargs)
threshold = vad_options.threshold
onset = vad_options.onset
min_speech_duration_ms = vad_options.min_speech_duration_ms
max_speech_duration_s = vad_options.max_speech_duration_s
min_silence_duration_ms = vad_options.min_silence_duration_ms
window_size_samples = 512
speech_pad_ms = vad_options.speech_pad_ms
sampling_rate = 16000
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
max_speech_samples = (
@@ -80,20 +80,16 @@ def get_speech_timestamps(
audio_length_samples = len(audio)
model = get_vad_model()
state, context = model.get_initial_states(batch_size=1)
speech_probs = []
for current_start_sample in range(0, audio_length_samples, window_size_samples):
chunk = audio[current_start_sample : current_start_sample + window_size_samples]
if len(chunk) < window_size_samples:
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob, state, context = model(chunk, state, context, sampling_rate)
speech_probs.append(speech_prob)
padded_audio = np.pad(
audio.numpy(), (0, window_size_samples - audio.shape[0] % window_size_samples)
)
speech_probs = model(padded_audio.reshape(1, -1)).squeeze(0)
triggered = False
speeches = []
current_speech = {}
neg_threshold = threshold - 0.15
offset = vad_options.offset
# to save potential segment end (and tolerate some silence)
temp_end = 0
@@ -101,12 +97,12 @@ def get_speech_timestamps(
prev_end = next_start = 0
for i, speech_prob in enumerate(speech_probs):
if (speech_prob >= threshold) and temp_end:
if (speech_prob >= onset) and temp_end:
temp_end = 0
if next_start < prev_end:
next_start = window_size_samples * i
if (speech_prob >= threshold) and not triggered:
if (speech_prob >= onset) and not triggered:
triggered = True
current_speech["start"] = window_size_samples * i
continue
@@ -133,7 +129,7 @@ def get_speech_timestamps(
triggered = False
continue
if (speech_prob < neg_threshold) and triggered:
if (speech_prob < offset) and triggered:
if not temp_end:
temp_end = window_size_samples * i
# condition to avoid cutting in very short silence
@@ -184,12 +180,27 @@ def get_speech_timestamps(
return speeches
def collect_chunks(audio: torch.Tensor, chunks: List[dict]) -> torch.Tensor:
"""Collects and concatenates audio chunks."""
def collect_chunks(
audio: torch.Tensor, chunks: List[dict], sampling_rate: int = 16000
) -> Tuple[List[torch.Tensor], List[Dict[str, int]]]:
"""Collects audio chunks."""
if not chunks:
return torch.tensor([], dtype=torch.float32)
chunk_metadata = {
"start_time": 0,
"end_time": 0,
}
return [torch.tensor([], dtype=torch.float32)], [chunk_metadata]
return torch.cat([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
audio_chunks = []
chunks_metadata = []
for chunk in chunks:
chunk_metadata = {
"start_time": chunk["start"] / sampling_rate,
"end_time": chunk["end"] / sampling_rate,
}
audio_chunks.append(audio[chunk["start"] : chunk["end"]])
chunks_metadata.append(chunk_metadata)
return audio_chunks, chunks_metadata
class SpeechTimestampsMap:
@@ -233,12 +244,13 @@ class SpeechTimestampsMap:
@functools.lru_cache
def get_vad_model():
"""Returns the VAD model instance."""
path = os.path.join(get_assets_path(), "silero_vad.onnx")
return SileroVADModel(path)
encoder_path = os.path.join(get_assets_path(), "silero_encoder_v5.onnx")
decoder_path = os.path.join(get_assets_path(), "silero_decoder_v5.onnx")
return SileroVADModel(encoder_path, decoder_path)
class SileroVADModel:
def __init__(self, path):
def __init__(self, encoder_path, decoder_path):
try:
import onnxruntime
except ImportError as e:
@@ -247,331 +259,84 @@ class SileroVADModel:
) from e
opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
opts.inter_op_num_threads = 0
opts.intra_op_num_threads = 0
opts.log_severity_level = 4
self.session = onnxruntime.InferenceSession(
path,
self.encoder_session = onnxruntime.InferenceSession(
encoder_path,
providers=["CPUExecutionProvider"],
sess_options=opts,
)
self.decoder_session = onnxruntime.InferenceSession(
decoder_path,
providers=["CPUExecutionProvider"],
sess_options=opts,
)
def get_initial_states(self, batch_size: int):
state = np.zeros((2, batch_size, 128), dtype=np.float32)
context = np.zeros((batch_size, 64), dtype=np.float32)
return state, context
def __call__(self, x, state, context, sr: int):
if len(x.shape) == 1:
x = np.expand_dims(x, 0)
if len(x.shape) > 2:
raise ValueError(
f"Too many dimensions for input audio chunk {len(x.shape)}"
)
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")
x = np.concatenate([context, x], axis=1)
ort_inputs = {
"input": x,
"state": state,
"sr": np.array(sr, dtype="int64"),
}
out, state = self.session.run(None, ort_inputs)
context = x[..., -64:]
return out, state, context
# BSD 2-Clause License
# Copyright (c) 2024, Max Bain
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# The code below is copied from whisper-x (https://github.com/m-bain/whisperX)
# and adapted for faster_whisper.
class SegmentX:
def __init__(self, start, end, speaker=None):
self.start = start
self.end = end
self.speaker = speaker
class VoiceActivitySegmentation(VoiceActivityDetection, ABC):
"""Pipeline wrapper class for Voice Activity Segmentation based on VAD scores."""
def __init__(
self,
segmentation: PipelineModel = "pyannote/segmentation",
device: Optional[Union[str, torch.device]] = None,
fscore: bool = False,
use_auth_token: Optional[str] = None,
**inference_kwargs,
def __call__(
self, audio: np.ndarray, num_samples: int = 512, context_size_samples: int = 64
):
"""Initialize the pipeline with the model name and the optional device.
assert (
audio.ndim == 2
), "Input should be a 2D tensor with size (batch_size, num_samples)"
assert (
audio.shape[1] % num_samples == 0
), "Input size should be a multiple of num_samples"
Args:
dict parameters of VoiceActivityDetection class from pyannote:
segmentation (PipelineModel): Loaded model name.
device (torch.device or None): Device to perform the segmentation.
fscore (bool): Flag indicating whether to compute F-score during inference.
use_auth_token (str or None): Optional authentication token for model access.
inference_kwargs (dict): Additional arguments from VoiceActivityDetection pipeline.
"""
super().__init__(
segmentation=segmentation,
device=device,
fscore=fscore,
use_auth_token=use_auth_token,
**inference_kwargs,
batch_size = audio.shape[0]
state = np.zeros((2, batch_size, 128), dtype="float32")
context = np.zeros(
(batch_size, context_size_samples),
dtype="float32",
)
def apply(
self, file: AudioFile, hook: Optional[Callable] = None
) -> SlidingWindowFeature:
"""Apply voice activity detection on the audio file.
batched_audio = audio.reshape(batch_size, -1, num_samples)
context = batched_audio[..., -context_size_samples:]
context[:, -1] = 0
context = np.roll(context, 1, 1)
batched_audio = np.concatenate([context, batched_audio], 2)
Args:
file (AudioFile): Processed file.
hook (callable): Hook called with signature: hook("step_name", step_artefact, file=file)
batched_audio = batched_audio.reshape(-1, num_samples + context_size_samples)
Returns:
segmentations (SlidingWindowFeature): Voice activity segmentation.
"""
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)
encoder_output = self.encoder_session.run(None, {"input": batched_audio})[0]
encoder_output = encoder_output.reshape(batch_size, -1, 128)
# apply segmentation model if needed
# output shape is (num_chunks, num_frames, 1)
if self.training:
if self.CACHED_SEGMENTATION in file:
segmentations = file[self.CACHED_SEGMENTATION]
else:
segmentations = self._segmentation(file)
file[self.CACHED_SEGMENTATION] = segmentations
else:
segmentations: SlidingWindowFeature = self._segmentation(file)
return segmentations
class BinarizeVadScores:
"""Binarize detection scores using hysteresis thresholding.
Reference:
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
RNN-based Voice Activity Detection", InterSpeech 2015.
Modified by Max Bain to include WhisperX's min-cut operation
https://arxiv.org/abs/2303.00747
"""
def __init__(
self,
onset: float = 0.5,
offset: Optional[float] = None,
min_duration_on: float = 0.0,
min_duration_off: float = 0.0,
pad_onset: float = 0.0,
pad_offset: float = 0.0,
max_duration: float = float("inf"),
):
"""Initializes the parameters for Binarizing the VAD scores.
Args:
onset (float, optional):
Onset threshold. Defaults to 0.5.
offset (float, optional):
Offset threshold. Defaults to `onset`.
min_duration_on (float, optional):
Remove active regions shorter than that many seconds. Defaults to 0s.
min_duration_off (float, optional):
Fill inactive regions shorter than that many seconds. Defaults to 0s.
pad_onset (float, optional):
Extend active regions by moving their start time by that many seconds.
Defaults to 0s.
pad_offset (float, optional):
Extend active regions by moving their end time by that many seconds.
Defaults to 0s.
max_duration (float):
The maximum length of an active segment.
"""
super().__init__()
self.onset = onset
self.offset = offset or onset
self.pad_onset = pad_onset
self.pad_offset = pad_offset
self.min_duration_on = min_duration_on
self.min_duration_off = min_duration_off
self.max_duration = max_duration
def __get_active_regions(self, scores: SlidingWindowFeature) -> Annotation:
"""Extract active regions from VAD scores.
Args:
scores (SlidingWindowFeature): Detection scores.
Returns:
active (Annotation): Active regions.
"""
num_frames, num_classes = scores.data.shape
frames = scores.sliding_window
timestamps = [frames[i].middle for i in range(num_frames)]
# annotation meant to store 'active' regions
active = Annotation()
for k, k_scores in enumerate(scores.data.T):
label = k if scores.labels is None else scores.labels[k]
# initial state
start = timestamps[0]
is_active = k_scores[0] > self.onset
curr_scores = [k_scores[0]]
curr_timestamps = [start]
t = start
# optionally add `strict=False` for python 3.10 or later
for t, y in zip(timestamps[1:], k_scores[1:]):
# currently active
if is_active:
curr_duration = t - start
if curr_duration > self.max_duration:
search_after = len(curr_scores) // 2
# divide segment
min_score_div_idx = search_after + np.argmin(
curr_scores[search_after:]
)
min_score_t = curr_timestamps[min_score_div_idx]
region = Segment(
start - self.pad_onset, min_score_t + self.pad_offset
)
active[region, k] = label
start = curr_timestamps[min_score_div_idx]
curr_scores = curr_scores[min_score_div_idx + 1 :]
curr_timestamps = curr_timestamps[min_score_div_idx + 1 :]
# switching from active to inactive
elif y < self.offset:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label
start = t
is_active = False
curr_scores = []
curr_timestamps = []
curr_scores.append(y)
curr_timestamps.append(t)
# currently inactive
else:
# switching from inactive to active
if y > self.onset:
start = t
is_active = True
# if active at the end, add final region
if is_active:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label
return active
def __call__(self, scores: SlidingWindowFeature) -> Annotation:
"""Binarize detection scores.
Args:
scores (SlidingWindowFeature): Detection scores.
Returns:
active (Annotation): Binarized scores.
"""
active = self.__get_active_regions(scores)
# because of padding, some active regions might be overlapping: merge them.
# also: fill same speaker gaps shorter than min_duration_off
if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
if self.max_duration < float("inf"):
raise NotImplementedError("This would break current max_duration param")
active = active.support(collar=self.min_duration_off)
# remove tracks shorter than min_duration_on
if self.min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < self.min_duration_on:
del active[segment, track]
return active
def merge_chunks(
segments,
chunk_length,
onset: float = 0.5,
offset: Optional[float] = None,
edge_padding: float = 0.1,
):
"""
Merge operation described in whisper-x paper
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
assert chunk_length > 0
binarize = BinarizeVadScores(max_duration=chunk_length, onset=onset, offset=offset)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(
SegmentX(
max(0.0, speech_turn.start - edge_padding),
speech_turn.end + edge_padding,
"UNKNOWN",
decoder_outputs = []
for window in np.split(encoder_output, encoder_output.shape[1], axis=1):
out, state = self.decoder_session.run(
None, {"input": window.squeeze(1), "state": state}
)
) # 100ms edge padding to account for edge errors
decoder_outputs.append(out)
if len(segments_list) == 0:
print("No active speech found in audio")
out = np.stack(decoder_outputs, axis=1).squeeze(-1)
return out
def merge_segments(segments_list, vad_options: VadOptions, sampling_rate: int = 16000):
if not segments_list:
return []
# Make sur the starting point is the start of the segment.
curr_start = segments_list[0].start
curr_end = 0
seg_idxs = []
merged_segments = []
edge_padding = vad_options.speech_pad_ms * sampling_rate // 1000
chunk_length = vad_options.max_speech_duration_s * sampling_rate
curr_start = segments_list[0]["start"]
for idx, seg in enumerate(segments_list):
# if any segment start timing is less than previous segment end timing,
# reset the edge padding. Similarly for end timing.
if idx > 0:
if seg.start < segments_list[idx - 1].end:
seg.start += edge_padding
if seg["start"] < segments_list[idx - 1]["end"]:
seg["start"] += edge_padding
if idx < len(segments_list) - 1:
if seg.end > segments_list[idx + 1].start:
seg.end -= edge_padding
if seg["end"] > segments_list[idx + 1]["start"]:
seg["end"] -= edge_padding
if seg.end - curr_start > chunk_length and curr_end - curr_start > 0:
if seg["end"] - curr_start > chunk_length and curr_end - curr_start > 0:
merged_segments.append(
{
"start": curr_start,
@@ -579,12 +344,10 @@ def merge_chunks(
"segments": seg_idxs,
}
)
curr_start = seg.start
curr_start = seg["start"]
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
curr_end = seg["end"]
seg_idxs.append((seg["start"], seg["end"]))
# add final
merged_segments.append(
{