Mirror of https://github.com/SYSTRAN/faster-whisper.git (synced 2026-01-09 13:38:01 -05:00)
Use Silero VAD in Batched Mode (#936)
Replace Pyannote VAD with Silero to reduce code duplication and requirements
BIN  faster_whisper/assets/silero_decoder_v5.onnx  (new file, binary not shown)
BIN  faster_whisper/assets/silero_encoder_v5.onnx  (new file, binary not shown)
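For orientation, a minimal usage sketch (not part of the diff): after this change, `BatchedInferencePipeline` only wraps a `WhisperModel`, and VAD in batched mode is driven by `vad_filter`/`vad_parameters` using the bundled Silero ONNX models instead of a separate pyannote model. The model size, file name, and batch size below are placeholders.

```python
from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("small", device="cpu", compute_type="int8")
pipeline = BatchedInferencePipeline(model)  # no VAD-specific constructor arguments anymore

# vad_filter=True runs the Silero encoder/decoder pair shipped in faster_whisper/assets
segments, info = pipeline.transcribe("audio.wav", batch_size=8, vad_filter=True)
for segment in segments:
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
```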
faster_whisper/transcribe.py

@@ -7,6 +7,7 @@ import zlib
from collections import Counter, defaultdict
from inspect import signature
from math import ceil
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union

import ctranslate2

@@ -14,26 +15,18 @@ import numpy as np
import tokenizers
import torch

from pyannote.audio import Model
from tqdm import tqdm

from faster_whisper.audio import decode_audio, pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import (
download_model,
format_timestamp,
get_assets_path,
get_end,
get_logger,
)
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
from faster_whisper.vad import (
SpeechTimestampsMap,
VadOptions,
VoiceActivitySegmentation,
collect_chunks,
get_speech_timestamps,
merge_chunks,
merge_segments,
)
@@ -115,67 +108,26 @@ class BatchedInferencePipeline:
def __init__(
self,
model,
use_vad_model: bool = True,
options: Optional[NamedTuple] = None,
tokenizer=None,
chunk_length: int = 30,
vad_device: Union[int, str, "torch.device"] = "auto",
vad_onset: float = 0.500,
vad_offset: float = 0.363,
language: Optional[str] = None,
):
self.model: WhisperModel = model
self.tokenizer = tokenizer
self.options = options
self.preset_language = language
self.use_vad_model = use_vad_model
self.vad_onset = vad_onset
self.vad_offset = vad_offset
self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
if self.use_vad_model:
self.vad_device = self.get_device(vad_device)
self.vad_model = self.load_vad_model(
vad_onset=self.vad_onset, vad_offset=self.vad_offset
)
else:
self.vad_model = None
self.chunk_length = chunk_length # VAD merging size
self.last_speech_timestamp = 0.0

def get_device(self, device: Union[int, str, "torch.device"]):
"""
Converts the input device into a torch.device object.

The input can be an integer, a string, or a `torch.device` object.

The function handles a special case where the input device is "auto".
When "auto" is specified, the device will default to the
device of the model (self.model.device). If the model's device is also "auto",
it selects "cuda" if a CUDA-capable device is available; otherwise, it selects "cpu".
"""
if isinstance(device, torch.device):
return device
elif isinstance(device, str):
if device == "auto" and self.model.device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
elif device == "auto":
device = self.model.device
return torch.device(device)
elif device < 0:
return torch.device("cpu")
else:
return torch.device(f"cuda:{device}")

def forward(self, features, segments_metadata, **forward_params):
def forward(self, features, chunks_metadata, **forward_params):
encoder_output, outputs = self.model.generate_segment_batched(
features, self.tokenizer, forward_params
)

segmented_outputs = []
segment_sizes = []
for segment_metadata, output in zip(segments_metadata, outputs):
duration = segment_metadata["end_time"] - segment_metadata["start_time"]
segment_size = int(duration * self.model.frames_per_second)
for chunk_metadata, output in zip(chunks_metadata, outputs):
duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
segment_size = int(ceil(duration) * self.model.frames_per_second)
segment_sizes.append(segment_size)
(
subsegments,

@@ -184,7 +136,7 @@ class BatchedInferencePipeline:
) = self.model._split_segments_by_timestamps(
tokenizer=self.tokenizer,
tokens=output["tokens"],
time_offset=segment_metadata["start_time"],
time_offset=chunk_metadata["start_time"],
segment_size=segment_size,
segment_duration=duration,
seek=0,
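The chunk-duration to feature-frame conversion above now rounds the duration up before scaling. A small worked example, assuming 100 feature frames per second (the standard Whisper 10 ms hop at 16 kHz):

```python
from math import ceil

frames_per_second = 100  # assumed: 10 ms feature hop at 16 kHz
duration = 17.3          # seconds, i.e. chunk end_time - start_time

old_segment_size = int(duration * frames_per_second)        # 1730
new_segment_size = int(ceil(duration) * frames_per_second)  # 1800
print(old_segment_size, new_segment_size)
```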
@@ -252,43 +204,9 @@ class BatchedInferencePipeline:

return language, language_probability, task, all_language_probs

@staticmethod
def audio_split(audio, segments, sampling_rate):
"""Returns splitted audio chunks as iterator"""
audio_segments = []
segments_metadata = []
for seg in segments:
f1 = int(seg["start"] * sampling_rate)
f2 = int(seg["end"] * sampling_rate)
seg_metadata = {
"start_time": seg["start"],
"end_time": seg["end"],
"stitched_seg": seg["segments"],
}
audio_segments.append(audio[f1:f2])
segments_metadata.append(seg_metadata)
return audio_segments, segments_metadata

def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
vad_model = Model.from_pretrained(self.vad_model_path)
hyperparameters = {
"onset": vad_onset,
"offset": vad_offset,
"min_duration_on": 0.1,
"min_duration_off": 0.1,
}

vad_pipeline = VoiceActivitySegmentation(
segmentation=vad_model, device=torch.device(self.vad_device)
)
vad_pipeline.instantiate(hyperparameters)
return vad_pipeline

def transcribe(
self,
audio: Union[str, torch.Tensor, np.ndarray],
vad_segments: Optional[List[dict]] = None,
batch_size: int = 16,
audio: Union[str, BinaryIO, torch.Tensor, np.ndarray],
language: Optional[str] = None,
task: str = None,
log_progress: bool = False,

@@ -314,26 +232,26 @@ class BatchedInferencePipeline:
prefix: Optional[str] = None,
suppress_blank: bool = True,
suppress_tokens: Optional[List[int]] = [-1],
without_timestamps: bool = True,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
vad_filter: bool = True,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
chunk_length: Optional[int] = None,
clip_timestamps: Optional[List[dict]] = None,
batch_size: int = 16,
hotwords: Optional[str] = None,
word_timestamps: bool = False,
without_timestamps: bool = True,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""transcribe audio in chunks in batched fashion and return with language info.

Arguments:
audio: audio file as numpy array/path for batched transcription.
vad_segments: Optionally provide list of dictionaries each containing "start", "end",
and "segments" keys.
"start" and "end" keys specify the start and end of the voiced region within
30 sec boundary. An additional key "segments" contains all the start
and end of voiced regions within that 30sec boundary as a list of tuples.
If no vad_segments specified, it uses internal vad model automatically segment them.
batch_size: the maximum number of parallel requests to model for decoding.
language: The language spoken in the audio.
task: either "transcribe" or "translate".
audio: Path to the input file (or a file-like object), or the audio waveform.
language: The language spoken in the audio. It should be a language code such
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
of audio.
task: Task to execute (transcribe or translate).
log_progress: whether to show progress bar or not.
beam_size: Beam size to use for decoding.
best_of: Number of candidates when sampling with non-zero temperature.
@@ -350,8 +268,8 @@ class BatchedInferencePipeline:
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.

@@ -361,18 +279,29 @@ class BatchedInferencePipeline:
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in `tokenizer.non_speech_tokens()`.
without_timestamps: Only sample text tokens.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
Set as False.
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
the maximum will be set by the default max_length.
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
clip_timestamps: Optionally provide list of dictionaries each containing "start" and
"end" keys that specify the start and end of the voiced region within
`chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used.
batch_size: the maximum number of parallel requests to model for decoding.
hotwords:
Hotwords/hint phrases to the model. Has no effect if prefix is not None.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
Set as False.
without_timestamps: Only sample text tokens.

Static params: (Fixed for batched version)
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.

@@ -390,28 +319,18 @@ class BatchedInferencePipeline:
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected. set as None.
clip_timestamps:
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
process. The last end timestamp defaults to the end of the file. Set as "0".

unused:
language_detection_threshold: If the maximum probability of the language tokens is
higher than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.

Returns:
A tuple with:

- a generator over transcribed batched segments.
- an instance of TranscriptionInfo.
- a generator over transcribed segments
- an instance of TranscriptionInfo
"""

sampling_rate = self.model.feature_extractor.sampling_rate
@@ -422,29 +341,32 @@ class BatchedInferencePipeline:
audio = decode_audio(audio, sampling_rate=sampling_rate)
duration = audio.shape[0] / sampling_rate

chunk_length = chunk_length or self.model.feature_extractor.chunk_length
# if no segment split is provided, use vad_model and generate segments
if not vad_segments:
# run the audio if it is less than 30 sec even without vad_segments
if self.use_vad_model:
vad_segments = self.vad_model(
{
"waveform": audio.unsqueeze(0),
"sample_rate": 16000,
}
)
vad_segments = merge_chunks(
vad_segments,
self.chunk_length,
onset=self.vad_onset,
offset=self.vad_offset,
)
elif duration < self.chunk_length:
vad_segments = [
{"start": 0.0, "end": duration, "segments": [(0.0, duration)]}
]
if not clip_timestamps:
if vad_filter:
if vad_parameters is None:
vad_parameters = VadOptions(
max_speech_duration_s=chunk_length,
min_silence_duration_ms=160,
)
elif isinstance(vad_parameters, dict):
if "max_speech_duration_s" in vad_parameters.keys():
vad_parameters.pop("max_speech_duration_s")

vad_parameters = VadOptions(
**vad_parameters, max_speech_duration_s=chunk_length
)

active_segments = get_speech_timestamps(audio, vad_parameters)
clip_timestamps = merge_segments(active_segments, vad_parameters)
# run the audio if it is less than 30 sec even without clip_timestamps
elif duration < chunk_length:
clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
else:
raise RuntimeError(
"No vad segments found. Set 'use_vad_model' to True while loading the model"
"No clip timestamps found. "
"Set 'vad_filter' to True or provide 'clip_timestamps'."
)
if self.model.model.is_multilingual:
language = language or self.preset_language
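The new clip_timestamps path above can be reproduced in isolation. A minimal sketch, assuming 16 kHz mono audio already loaded as a float tensor (the real pipeline gets it from decode_audio); on silence or noise the result is simply an empty list:

```python
import numpy as np
import torch

from faster_whisper.vad import VadOptions, get_speech_timestamps, merge_segments

sampling_rate = 16000
chunk_length = 30  # seconds, matching the feature extractor default
audio = torch.from_numpy(np.zeros(10 * sampling_rate, dtype=np.float32))  # placeholder waveform

vad_parameters = VadOptions(
    max_speech_duration_s=chunk_length,
    min_silence_duration_ms=160,
)
active_segments = get_speech_timestamps(audio, vad_parameters)
clip_timestamps = merge_segments(active_segments, vad_parameters)
# Each entry: {"start": <sample>, "end": <sample>, "segments": [(start, end), ...]}
print(clip_timestamps)
```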
@@ -452,7 +374,7 @@ class BatchedInferencePipeline:
if language is not None:
self.model.logger.warning(
f"English-only model is used, but {language} language is"
"chosen, setting language to 'en'."
" chosen, setting language to 'en'."
)
language = "en"

@@ -463,8 +385,9 @@ class BatchedInferencePipeline:
all_language_probs,
) = self.get_language_and_tokenizer(audio, task, language)

duration_after_vad = sum(
segment["end"] - segment["start"] for segment in vad_segments
duration_after_vad = (
sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
/ sampling_rate
)

# batched options: see the difference with default options in WhisperModel

@@ -511,27 +434,26 @@ class BatchedInferencePipeline:
all_language_probs=all_language_probs,
)

audio_segments, segments_metadata = self.audio_split(
audio, vad_segments, sampling_rate
)
audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
to_cpu = (
self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
)
audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
padding=0
)
features = torch.stack(
[
self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
..., : self.model.feature_extractor.nb_max_frames
features = (
torch.stack(
[
self.model.feature_extractor(chunk, to_cpu=to_cpu)[
..., : self.model.feature_extractor.nb_max_frames
]
for chunk in audio_chunks
]
for audio_segment in audio_segments
]
)
if duration_after_vad
else []
)

segments = self._batched_segments_generator(
features,
segments_metadata,
chunks_metadata,
batch_size,
batched_options,
log_progress,

@@ -540,14 +462,14 @@ class BatchedInferencePipeline:
return segments, info

def _batched_segments_generator(
self, features, segments_metadata, batch_size, options, log_progress
self, features, chunks_metadata, batch_size, options, log_progress
):
pbar = tqdm(total=len(features), disable=not log_progress, position=0)
seg_idx = 0
for i in range(0, len(features), batch_size):
results = self.forward(
features[i : i + batch_size],
segments_metadata[i : i + batch_size],
chunks_metadata[i : i + batch_size],
**options._asdict(),
)

@@ -850,7 +772,8 @@ class WhisperModel:
elif isinstance(vad_parameters, dict):
vad_parameters = VadOptions(**vad_parameters)
speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio = collect_chunks(audio, speech_chunks)
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
audio = torch.cat(audio_chunks, dim=0)
duration_after_vad = audio.shape[0] / sampling_rate

self.logger.info(

@@ -1905,7 +1828,8 @@ class WhisperModel:
# get chunks of audio that contain speech
speech_chunks = get_speech_timestamps(audio, vad_params)
# merge chunks of audio that contain speech into a single array
audio = collect_chunks(audio, speech_chunks)
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
audio = torch.cat(audio_chunks, dim=0)

# calculate new duration of audio without silence
duration_vad = audio.shape[0] / sampling_rate
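The two WhisperModel call sites above rely on the reworked collect_chunks contract in faster_whisper.vad: it now returns the per-chunk tensors together with start/end times in seconds, and the caller concatenates them itself. A small self-contained illustration with made-up sample offsets:

```python
import torch

from faster_whisper.vad import collect_chunks

audio = torch.arange(32000, dtype=torch.float32)  # 2 s of dummy samples at 16 kHz
speech_chunks = [{"start": 0, "end": 8000}, {"start": 16000, "end": 24000}]

audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
audio = torch.cat(audio_chunks, dim=0)

print(audio.shape[0] / 16000)  # 1.0 -> duration_after_vad
print(chunks_metadata)
# [{'start_time': 0.0, 'end_time': 0.5}, {'start_time': 1.0, 'end_time': 1.5}]
```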
faster_whisper/vad.py

@@ -2,18 +2,11 @@ import bisect
import functools
import os

from abc import ABC
from collections.abc import Callable
from typing import List, NamedTuple, Optional, Union
from typing import Dict, List, NamedTuple, Optional, Tuple

import numpy as np
import torch

from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, Segment, SlidingWindowFeature

from faster_whisper.utils import get_assets_path

@@ -22,9 +15,14 @@ class VadOptions(NamedTuple):
"""VAD options.

Attributes:
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
onset: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
offset: Silence threshold for determining the end of speech. If a probability is lower than
the offset, it is always considered silence. Values higher than offset are only considered
speech if the previous sample was classified as speech; otherwise, they are treated as
silence. This parameter helps refine the detection of speech transitions, ensuring smoother
segment boundaries.
min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
than max_speech_duration_s will be split at the timestamp of the last silence that

@@ -35,8 +33,9 @@ class VadOptions(NamedTuple):
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
"""

threshold: float = 0.5
min_speech_duration_ms: int = 250
onset: float = 0.5
offset: float = onset - 0.15
min_speech_duration_ms: int = 0
max_speech_duration_s: float = float("inf")
min_silence_duration_ms: int = 2000
speech_pad_ms: int = 400
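For callers, the practical effect of the renamed fields: `threshold` becomes `onset`, and `offset` controls the hysteresis for ending a speech region. Note that, as a NamedTuple default, `offset` is evaluated once from the default `onset` (0.5 - 0.15 = 0.35) and does not follow a custom `onset` automatically; a short check:

```python
from faster_whisper.vad import VadOptions

options = VadOptions(onset=0.6)
print(options.onset, options.offset)   # 0.6 0.35 -- offset keeps the class-level default

options = VadOptions(onset=0.6, offset=0.45)
print(options.offset)                  # 0.45
```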
@@ -45,6 +44,7 @@ class VadOptions(NamedTuple):
def get_speech_timestamps(
audio: torch.Tensor,
vad_options: Optional[VadOptions] = None,
sampling_rate: int = 16000,
**kwargs,
) -> List[dict]:
"""This method is used for splitting long audios into speech chunks using silero VAD.

@@ -52,6 +52,7 @@ def get_speech_timestamps(
Args:
audio: One dimensional float array.
vad_options: Options for VAD processing.
sampling rate: Sampling rate of the audio.
kwargs: VAD options passed as keyword arguments for backward compatibility.

Returns:

@@ -60,13 +61,12 @@ def get_speech_timestamps(
if vad_options is None:
vad_options = VadOptions(**kwargs)

threshold = vad_options.threshold
onset = vad_options.onset
min_speech_duration_ms = vad_options.min_speech_duration_ms
max_speech_duration_s = vad_options.max_speech_duration_s
min_silence_duration_ms = vad_options.min_silence_duration_ms
window_size_samples = 512
speech_pad_ms = vad_options.speech_pad_ms
sampling_rate = 16000
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
max_speech_samples = (

@@ -80,20 +80,16 @@ def get_speech_timestamps(
audio_length_samples = len(audio)

model = get_vad_model()
state, context = model.get_initial_states(batch_size=1)

speech_probs = []
for current_start_sample in range(0, audio_length_samples, window_size_samples):
chunk = audio[current_start_sample : current_start_sample + window_size_samples]
if len(chunk) < window_size_samples:
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob, state, context = model(chunk, state, context, sampling_rate)
speech_probs.append(speech_prob)
padded_audio = np.pad(
audio.numpy(), (0, window_size_samples - audio.shape[0] % window_size_samples)
)
speech_probs = model(padded_audio.reshape(1, -1)).squeeze(0)

triggered = False
speeches = []
current_speech = {}
neg_threshold = threshold - 0.15
offset = vad_options.offset

# to save potential segment end (and tolerate some silence)
temp_end = 0
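The per-window Python loop over 512-sample chunks is replaced by a single batched call: the waveform is padded to a multiple of the window size and scored in one pass. A quick check of the padding arithmetic (sizes are arbitrary):

```python
import numpy as np

window_size_samples = 512
audio = np.zeros(12_345, dtype=np.float32)  # stand-in waveform

padded = np.pad(audio, (0, window_size_samples - audio.shape[0] % window_size_samples))
print(padded.shape[0])                         # 12800
print(padded.shape[0] // window_size_samples)  # 25 windows -> 25 speech probabilities
```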
@@ -101,12 +97,12 @@ def get_speech_timestamps(
prev_end = next_start = 0

for i, speech_prob in enumerate(speech_probs):
if (speech_prob >= threshold) and temp_end:
if (speech_prob >= onset) and temp_end:
temp_end = 0
if next_start < prev_end:
next_start = window_size_samples * i

if (speech_prob >= threshold) and not triggered:
if (speech_prob >= onset) and not triggered:
triggered = True
current_speech["start"] = window_size_samples * i
continue

@@ -133,7 +129,7 @@ def get_speech_timestamps(
triggered = False
continue

if (speech_prob < neg_threshold) and triggered:
if (speech_prob < offset) and triggered:
if not temp_end:
temp_end = window_size_samples * i
# condition to avoid cutting in very short silence

@@ -184,12 +180,27 @@ def get_speech_timestamps(
return speeches


def collect_chunks(audio: torch.Tensor, chunks: List[dict]) -> torch.Tensor:
"""Collects and concatenates audio chunks."""
def collect_chunks(
audio: torch.Tensor, chunks: List[dict], sampling_rate: int = 16000
) -> Tuple[List[torch.Tensor], List[Dict[str, int]]]:
"""Collects audio chunks."""
if not chunks:
return torch.tensor([], dtype=torch.float32)
chunk_metadata = {
"start_time": 0,
"end_time": 0,
}
return [torch.tensor([], dtype=torch.float32)], [chunk_metadata]

return torch.cat([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
audio_chunks = []
chunks_metadata = []
for chunk in chunks:
chunk_metadata = {
"start_time": chunk["start"] / sampling_rate,
"end_time": chunk["end"] / sampling_rate,
}
audio_chunks.append(audio[chunk["start"] : chunk["end"]])
chunks_metadata.append(chunk_metadata)
return audio_chunks, chunks_metadata

class SpeechTimestampsMap:

@@ -233,12 +244,13 @@ class SpeechTimestampsMap:
@functools.lru_cache
def get_vad_model():
"""Returns the VAD model instance."""
path = os.path.join(get_assets_path(), "silero_vad.onnx")
return SileroVADModel(path)
encoder_path = os.path.join(get_assets_path(), "silero_encoder_v5.onnx")
decoder_path = os.path.join(get_assets_path(), "silero_decoder_v5.onnx")
return SileroVADModel(encoder_path, decoder_path)


class SileroVADModel:
def __init__(self, path):
def __init__(self, encoder_path, decoder_path):
try:
import onnxruntime
except ImportError as e:

@@ -247,331 +259,84 @@ class SileroVADModel:
) from e

opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
opts.inter_op_num_threads = 0
opts.intra_op_num_threads = 0
opts.log_severity_level = 4

self.session = onnxruntime.InferenceSession(
path,
self.encoder_session = onnxruntime.InferenceSession(
encoder_path,
providers=["CPUExecutionProvider"],
sess_options=opts,
)
self.decoder_session = onnxruntime.InferenceSession(
decoder_path,
providers=["CPUExecutionProvider"],
sess_options=opts,
)

def get_initial_states(self, batch_size: int):
state = np.zeros((2, batch_size, 128), dtype=np.float32)
context = np.zeros((batch_size, 64), dtype=np.float32)
return state, context

def __call__(self, x, state, context, sr: int):
if len(x.shape) == 1:
x = np.expand_dims(x, 0)
if len(x.shape) > 2:
raise ValueError(
f"Too many dimensions for input audio chunk {len(x.shape)}"
)
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")

x = np.concatenate([context, x], axis=1)

ort_inputs = {
"input": x,
"state": state,
"sr": np.array(sr, dtype="int64"),
}

out, state = self.session.run(None, ort_inputs)
context = x[..., -64:]

return out, state, context
# BSD 2-Clause License

# Copyright (c) 2024, Max Bain

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


# The code below is copied from whisper-x (https://github.com/m-bain/whisperX)
# and adapted for faster_whisper.
class SegmentX:
def __init__(self, start, end, speaker=None):
self.start = start
self.end = end
self.speaker = speaker
class VoiceActivitySegmentation(VoiceActivityDetection, ABC):
"""Pipeline wrapper class for Voice Activity Segmentation based on VAD scores."""

def __init__(
self,
segmentation: PipelineModel = "pyannote/segmentation",
device: Optional[Union[str, torch.device]] = None,
fscore: bool = False,
use_auth_token: Optional[str] = None,
**inference_kwargs,
def __call__(
self, audio: np.ndarray, num_samples: int = 512, context_size_samples: int = 64
):
"""Initialize the pipeline with the model name and the optional device.
assert (
audio.ndim == 2
), "Input should be a 2D tensor with size (batch_size, num_samples)"
assert (
audio.shape[1] % num_samples == 0
), "Input size should be a multiple of num_samples"

Args:
dict parameters of VoiceActivityDetection class from pyannote:
segmentation (PipelineModel): Loaded model name.
device (torch.device or None): Device to perform the segmentation.
fscore (bool): Flag indicating whether to compute F-score during inference.
use_auth_token (str or None): Optional authentication token for model access.
inference_kwargs (dict): Additional arguments from VoiceActivityDetection pipeline.
"""
super().__init__(
segmentation=segmentation,
device=device,
fscore=fscore,
use_auth_token=use_auth_token,
**inference_kwargs,
batch_size = audio.shape[0]

state = np.zeros((2, batch_size, 128), dtype="float32")
context = np.zeros(
(batch_size, context_size_samples),
dtype="float32",
)

def apply(
self, file: AudioFile, hook: Optional[Callable] = None
) -> SlidingWindowFeature:
"""Apply voice activity detection on the audio file.
batched_audio = audio.reshape(batch_size, -1, num_samples)
context = batched_audio[..., -context_size_samples:]
context[:, -1] = 0
context = np.roll(context, 1, 1)
batched_audio = np.concatenate([context, batched_audio], 2)

Args:
file (AudioFile): Processed file.
hook (callable): Hook called with signature: hook("step_name", step_artefact, file=file)
batched_audio = batched_audio.reshape(-1, num_samples + context_size_samples)

Returns:
segmentations (SlidingWindowFeature): Voice activity segmentation.
"""
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)
encoder_output = self.encoder_session.run(None, {"input": batched_audio})[0]
encoder_output = encoder_output.reshape(batch_size, -1, 128)

# apply segmentation model if needed
# output shape is (num_chunks, num_frames, 1)
if self.training:
if self.CACHED_SEGMENTATION in file:
segmentations = file[self.CACHED_SEGMENTATION]
else:
segmentations = self._segmentation(file)
file[self.CACHED_SEGMENTATION] = segmentations
else:
segmentations: SlidingWindowFeature = self._segmentation(file)

return segmentations
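The removed pyannote pipeline above is interleaved in this hunk with the new SileroVADModel.__call__, which scores the whole batch in two ONNX passes (encoder, then a per-frame decoder). The 64-sample left context for each 512-sample window is built by rolling the window tails. A toy illustration of that indexing, with tiny window and context sizes (4 and 2 instead of 512 and 64):

```python
import numpy as np

num_samples, context_size = 4, 2
audio = np.arange(1, 13, dtype="float32").reshape(1, -1)  # one item, three windows

windows = audio.reshape(1, -1, num_samples)
context = windows[..., -context_size:].copy()
context[:, -1] = 0  # zero the last window's tail; after the roll it becomes the first window's (silent) context
context = np.roll(context, 1, 1)  # shift tails right: window i gets the tail of window i-1
with_context = np.concatenate([context, windows], axis=2)
print(with_context[0])
# [[ 0.  0.  1.  2.  3.  4.]
#  [ 3.  4.  5.  6.  7.  8.]
#  [ 7.  8.  9. 10. 11. 12.]]
```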
class BinarizeVadScores:
"""Binarize detection scores using hysteresis thresholding.

Reference:
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
RNN-based Voice Activity Detection", InterSpeech 2015.

Modified by Max Bain to include WhisperX's min-cut operation
https://arxiv.org/abs/2303.00747

"""

def __init__(
self,
onset: float = 0.5,
offset: Optional[float] = None,
min_duration_on: float = 0.0,
min_duration_off: float = 0.0,
pad_onset: float = 0.0,
pad_offset: float = 0.0,
max_duration: float = float("inf"),
):
"""Initializes the parameters for Binarizing the VAD scores.

Args:
onset (float, optional):
Onset threshold. Defaults to 0.5.
offset (float, optional):
Offset threshold. Defaults to `onset`.
min_duration_on (float, optional):
Remove active regions shorter than that many seconds. Defaults to 0s.
min_duration_off (float, optional):
Fill inactive regions shorter than that many seconds. Defaults to 0s.
pad_onset (float, optional):
Extend active regions by moving their start time by that many seconds.
Defaults to 0s.
pad_offset (float, optional):
Extend active regions by moving their end time by that many seconds.
Defaults to 0s.
max_duration (float):
The maximum length of an active segment.
"""
super().__init__()

self.onset = onset
self.offset = offset or onset

self.pad_onset = pad_onset
self.pad_offset = pad_offset

self.min_duration_on = min_duration_on
self.min_duration_off = min_duration_off

self.max_duration = max_duration

def __get_active_regions(self, scores: SlidingWindowFeature) -> Annotation:
"""Extract active regions from VAD scores.

Args:
scores (SlidingWindowFeature): Detection scores.

Returns:
active (Annotation): Active regions.
"""
num_frames, num_classes = scores.data.shape
frames = scores.sliding_window
timestamps = [frames[i].middle for i in range(num_frames)]
# annotation meant to store 'active' regions
active = Annotation()
for k, k_scores in enumerate(scores.data.T):
label = k if scores.labels is None else scores.labels[k]

# initial state
start = timestamps[0]
is_active = k_scores[0] > self.onset
curr_scores = [k_scores[0]]
curr_timestamps = [start]
t = start
# optionally add `strict=False` for python 3.10 or later
for t, y in zip(timestamps[1:], k_scores[1:]):
# currently active
if is_active:
curr_duration = t - start
if curr_duration > self.max_duration:
search_after = len(curr_scores) // 2
# divide segment
min_score_div_idx = search_after + np.argmin(
curr_scores[search_after:]
)
min_score_t = curr_timestamps[min_score_div_idx]
region = Segment(
start - self.pad_onset, min_score_t + self.pad_offset
)
active[region, k] = label
start = curr_timestamps[min_score_div_idx]
curr_scores = curr_scores[min_score_div_idx + 1 :]
curr_timestamps = curr_timestamps[min_score_div_idx + 1 :]
# switching from active to inactive
elif y < self.offset:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label
start = t
is_active = False
curr_scores = []
curr_timestamps = []
curr_scores.append(y)
curr_timestamps.append(t)
# currently inactive
else:
# switching from inactive to active
if y > self.onset:
start = t
is_active = True

# if active at the end, add final region
if is_active:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label

return active

def __call__(self, scores: SlidingWindowFeature) -> Annotation:
"""Binarize detection scores.

Args:
scores (SlidingWindowFeature): Detection scores.

Returns:
active (Annotation): Binarized scores.
"""
active = self.__get_active_regions(scores)
# because of padding, some active regions might be overlapping: merge them.
# also: fill same speaker gaps shorter than min_duration_off
if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
if self.max_duration < float("inf"):
raise NotImplementedError("This would break current max_duration param")
active = active.support(collar=self.min_duration_off)

# remove tracks shorter than min_duration_on
if self.min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < self.min_duration_on:
del active[segment, track]

return active
def merge_chunks(
segments,
chunk_length,
onset: float = 0.5,
offset: Optional[float] = None,
edge_padding: float = 0.1,
):
"""
Merge operation described in whisper-x paper
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []

assert chunk_length > 0
binarize = BinarizeVadScores(max_duration=chunk_length, onset=onset, offset=offset)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(
SegmentX(
max(0.0, speech_turn.start - edge_padding),
speech_turn.end + edge_padding,
"UNKNOWN",
decoder_outputs = []
for window in np.split(encoder_output, encoder_output.shape[1], axis=1):
out, state = self.decoder_session.run(
None, {"input": window.squeeze(1), "state": state}
)
) # 100ms edge padding to account for edge errors
decoder_outputs.append(out)

if len(segments_list) == 0:
print("No active speech found in audio")
out = np.stack(decoder_outputs, axis=1).squeeze(-1)
return out
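Putting the new SileroVADModel pieces together end to end (encoder over all windows, then the stateful decoder frame by frame), a minimal smoke test; it assumes onnxruntime is installed, the bundled ONNX assets are present, and the input length is a multiple of 512 samples:

```python
import numpy as np

from faster_whisper.vad import get_vad_model

model = get_vad_model()  # cached SileroVADModel(encoder_path, decoder_path)

audio = np.zeros((1, 16384), dtype=np.float32)  # batch of 1, 32 windows of 512 samples
speech_probs = model(audio)
print(speech_probs.shape)  # expected: one probability per window, e.g. (1, 32)
```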
def merge_segments(segments_list, vad_options: VadOptions, sampling_rate: int = 16000):
if not segments_list:
return []

# Make sure the starting point is the start of the segment.
curr_start = segments_list[0].start
curr_end = 0
seg_idxs = []
merged_segments = []
edge_padding = vad_options.speech_pad_ms * sampling_rate // 1000
chunk_length = vad_options.max_speech_duration_s * sampling_rate

curr_start = segments_list[0]["start"]

for idx, seg in enumerate(segments_list):
# if any segment start timing is less than previous segment end timing,
# reset the edge padding. Similarly for end timing.
if idx > 0:
if seg.start < segments_list[idx - 1].end:
seg.start += edge_padding
if seg["start"] < segments_list[idx - 1]["end"]:
seg["start"] += edge_padding
if idx < len(segments_list) - 1:
if seg.end > segments_list[idx + 1].start:
seg.end -= edge_padding
if seg["end"] > segments_list[idx + 1]["start"]:
seg["end"] -= edge_padding

if seg.end - curr_start > chunk_length and curr_end - curr_start > 0:
if seg["end"] - curr_start > chunk_length and curr_end - curr_start > 0:
merged_segments.append(
{
"start": curr_start,

@@ -579,12 +344,10 @@ def merge_chunks(
"segments": seg_idxs,
}
)
curr_start = seg.start
curr_start = seg["start"]
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
curr_end = seg["end"]
seg_idxs.append((seg["start"], seg["end"]))
# add final
merged_segments.append(
{
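The hunk ends mid-way through the final merged_segments.append(...), but the shape of the result is already clear from the loop: each merged chunk keeps its overall sample span plus the individual speech regions it absorbed. A hypothetical input (sample offsets at 16 kHz, values made up for illustration):

```python
from faster_whisper.vad import VadOptions, merge_segments

speech_chunks = [
    {"start": 0, "end": 4 * 16000},
    {"start": 5 * 16000, "end": 9 * 16000},
    {"start": 40 * 16000, "end": 44 * 16000},
]
options = VadOptions(max_speech_duration_s=30)

for merged in merge_segments(speech_chunks, options):
    # e.g. {"start": ..., "end": ..., "segments": [(start, end), ...]}
    print(merged["start"], merged["end"], merged["segments"])
```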