Mirror of https://github.com/SYSTRAN/faster-whisper.git, synced 2026-01-12 23:18:06 -05:00

Compare commits (24 commits)
Commit SHA1s:
ba812f55a2
44466c7535
e3e46675b2
14ad587c98
9090997d25
dea24cbcc6
14ba1051f3
c26d609974
4bd98d5c5b
93001a9438
a0c3cb9802
fbeb1ba731
d3bfd0a305
43d4163fe0
700584b2e6
1383fd4d37
9e657b47cb
11fd8ab301
95164297ff
1b24f284c9
b568faec40
f32c0e8af3
8327d8cc64
22a5238b56
.github/workflows/ci.yml (vendored, 24 changes)
@@ -15,12 +15,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.8
uses: actions/setup-python@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: 3.8
python-version: '3.10'
- name: Install module
run: |

@@ -45,12 +45,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.8
uses: actions/setup-python@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: 3.8
python-version: '3.10'
- name: Install module
run: |

@@ -67,12 +67,12 @@ jobs:
needs: [check-code-format, run-tests]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.8
uses: actions/setup-python@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: 3.8
python-version: '3.10'
- name: Install dependencies
run: |
MANIFEST.in
@@ -1,4 +1,3 @@
include faster_whisper/assets/silero_encoder_v5.onnx
include faster_whisper/assets/silero_decoder_v5.onnx
include faster_whisper/assets/silero_vad_v6.onnx
include requirements.txt
include requirements.conversion.txt
README.md
@@ -56,7 +56,7 @@ For reference, here's the time and memory usage that are required to transcribe
## Requirements
* Python 3.8 or greater
* Python 3.10 or greater
Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. The audio is decoded with the Python library [PyAV](https://github.com/PyAV-Org/PyAV) which bundles the FFmpeg libraries in its package.

@@ -237,7 +237,7 @@ See more model and transcription options in the [`WhisperModel`](https://github.
Here is a non exhaustive list of open-source projects using faster-whisper. Feel free to add your project to the list!
* [faster-whisper-server](https://github.com/fedirz/faster-whisper-server) is an OpenAI compatible server using `faster-whisper`. It's easily deployable with Docker, works with OpenAI SDKs/CLI, supports streaming, and live transcription.
* [speaches](https://github.com/speaches-ai/speaches) is an OpenAI compatible server using `faster-whisper`. It's easily deployable with Docker, works with OpenAI SDKs/CLI, supports streaming, and live transcription.
* [WhisperX](https://github.com/m-bain/whisperX) is an award-winning Python library that offers speaker diarization and accurate word-level timestamps using wav2vec2 alignment
* [whisper-ctranslate2](https://github.com/Softcatala/whisper-ctranslate2) is a command line client based on faster-whisper and compatible with the original client from openai/whisper.
* [whisper-diarize](https://github.com/MahmoudAshraf97/whisper-diarization) is a speaker diarization tool that is based on faster-whisper and NVIDIA NeMo.

@@ -249,6 +249,8 @@ Here is a non exhaustive list of open-source projects using faster-whisper. Feel
* [Whisper-Streaming](https://github.com/ufal/whisper_streaming) implements real-time mode for offline Whisper-like speech-to-text models with faster-whisper as the most recommended back-end. It implements a streaming policy with self-adaptive latency based on the actual source complexity, and demonstrates the state of the art.
* [WhisperLive](https://github.com/collabora/WhisperLive) is a nearly-live implementation of OpenAI's Whisper which uses faster-whisper as the backend to transcribe audio in real-time.
* [Faster-Whisper-Transcriber](https://github.com/BBC-Esq/ctranslate2-faster-whisper-transcriber) is a simple but reliable voice transcriber that provides a user-friendly interface.
* [Open-dubbing](https://github.com/softcatala/open-dubbing) is open dubbing is an AI dubbing system which uses machine learning models to automatically translate and synchronize audio dialogue into different languages.
* [Whisper-FastAPI](https://github.com/heimoshuiyu/whisper-fastapi) whisper-fastapi is a very simple script that provides an API backend compatible with OpenAI, HomeAssistant, and Konele (Android voice typing) formats.
## Model conversion
Binary files (contents not shown): faster_whisper/assets/silero_vad_v6.onnx added as a new file, replacing the silero_encoder_v5.onnx and silero_decoder_v5.onnx assets dropped from MANIFEST.in above.
faster_whisper/tokenizer.py
@@ -67,6 +67,12 @@ class Tokenizer:
def no_timestamps(self) -> int:
return self.tokenizer.token_to_id("<|notimestamps|>")
@cached_property
def no_speech(self) -> int:
return self.tokenizer.token_to_id("<|nospeech|>") or self.tokenizer.token_to_id(
"<|nocaptions|>"
)
@property
def timestamp_begin(self) -> int:
return self.no_timestamps + 1
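The new no_speech property resolves the no-speech token id, falling back to the older <|nocaptions|> spelling. A minimal sketch of reading it through a loaded model, constructed the same way the test suite does (the "tiny" checkpoint and CPU settings are only illustrative):

    from faster_whisper import WhisperModel
    from faster_whisper.tokenizer import Tokenizer

    model = WhisperModel("tiny", device="cpu", compute_type="int8")

    # Second positional argument is the multilingual flag, as in the tests.
    tokenizer = Tokenizer(model.hf_tokenizer, False)
    print(tokenizer.no_speech)        # id of <|nospeech|> (or <|nocaptions|> as fallback)
    print(tokenizer.timestamp_begin)  # no_timestamps + 1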
faster_whisper/transcribe.py
@@ -25,7 +25,6 @@ from faster_whisper.vad import (
VadOptions,
collect_chunks,
get_speech_timestamps,
merge_segments,
)

@@ -125,7 +124,7 @@ class BatchedInferencePipeline:
segmented_outputs = []
segment_sizes = []
for chunk_metadata, output in zip(chunks_metadata, outputs):
duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
duration = chunk_metadata["duration"]
segment_size = int(ceil(duration) * self.model.frames_per_second)
segment_sizes.append(segment_size)
(

@@ -135,7 +134,7 @@ class BatchedInferencePipeline:
) = self.model._split_segments_by_timestamps(
tokenizer=tokenizer,
tokens=output["tokens"],
time_offset=chunk_metadata["start_time"],
time_offset=chunk_metadata["offset"],
segment_size=segment_size,
segment_duration=duration,
seek=0,

@@ -153,7 +152,7 @@ class BatchedInferencePipeline:
tokenizer.decode(subsegment["tokens"])
),
seek=int(
chunk_metadata["start_time"] * self.model.frames_per_second
chunk_metadata["offset"] * self.model.frames_per_second
),
)
for subsegment in subsegments

@@ -388,6 +387,10 @@ class BatchedInferencePipeline:
audio = decode_audio(audio, sampling_rate=sampling_rate)
duration = audio.shape[0] / sampling_rate
self.model.logger.info(
"Processing audio with duration %s", format_timestamp(duration)
)
chunk_length = chunk_length or self.model.feature_extractor.chunk_length
# if no segment split is provided, use vad_model and generate segments
if not clip_timestamps:
@@ -405,8 +408,7 @@ class BatchedInferencePipeline:
**vad_parameters, max_speech_duration_s=chunk_length
)
active_segments = get_speech_timestamps(audio, vad_parameters)
clip_timestamps = merge_segments(active_segments, vad_parameters)
clip_timestamps = get_speech_timestamps(audio, vad_parameters)
# run the audio if it is less than 30 sec even without clip_timestamps
elif duration < chunk_length:
clip_timestamps = [{"start": 0, "end": audio.shape[0]}]

@@ -416,12 +418,37 @@ class BatchedInferencePipeline:
"Set 'vad_filter' to True or provide 'clip_timestamps'."
)
audio_chunks, chunks_metadata = collect_chunks(
audio, clip_timestamps, max_duration=chunk_length
)
else:
clip_timestamps = [
{k: int(v * sampling_rate) for k, v in segment.items()}
for segment in clip_timestamps
]
audio_chunks, chunks_metadata = [], []
for clip in clip_timestamps:
audio_chunks.append(audio[clip["start"] : clip["end"]])
chunks_metadata.append(
{
"offset": clip["start"] / sampling_rate,
"duration": (clip["end"] - clip["start"]) / sampling_rate,
"segments": [clip],
}
)
duration_after_vad = (
sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
/ sampling_rate
)
audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
self.model.logger.info(
"VAD filter removed %s of audio",
format_timestamp(duration - duration_after_vad),
)
features = (
[self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
if duration_after_vad
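After this change, BatchedInferencePipeline.transcribe accepts user-provided clip_timestamps as a list of start/end dicts in seconds and builds the chunk metadata itself. A minimal sketch of that usage, mirroring the new test at the end of this diff (the audio file name is illustrative):

    from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio

    model = WhisperModel("tiny")  # small checkpoint chosen only for the example
    pipeline = BatchedInferencePipeline(model=model)

    audio = decode_audio("speech.wav")  # hypothetical input file

    # Transcribe only these two windows, given in seconds.
    clip_timestamps = [{"start": 0.0, "end": 11.0}, {"start": 11.0, "end": 22.0}]
    segments, info = pipeline.transcribe(audio, clip_timestamps=clip_timestamps)

    for segment in segments:
        print(segment.start, segment.end, segment.text)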
@@ -495,7 +522,11 @@ class BatchedInferencePipeline:
initial_prompt=initial_prompt,
prefix=prefix,
suppress_blank=suppress_blank,
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
suppress_tokens=(
get_suppressed_tokens(tokenizer, suppress_tokens)
if suppress_tokens
else suppress_tokens
),
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
max_new_tokens=max_new_tokens,

@@ -528,6 +559,7 @@ class BatchedInferencePipeline:
options,
log_progress,
)
segments = restore_speech_timestamps(segments, clip_timestamps, sampling_rate)
return segments, info
@@ -583,6 +615,8 @@ class WhisperModel:
download_root: Optional[str] = None,
local_files_only: bool = False,
files: dict = None,
revision: Optional[str] = None,
use_auth_token: Optional[Union[str, bool]] = None,
**model_kwargs,
):
"""Initializes the Whisper model.

@@ -614,6 +648,11 @@ class WhisperModel:
files: Load model files from the memory. This argument is a dictionary mapping file names
to file contents as file-like or bytes objects. If this is set, model_path acts as an
identifier for this model.
revision:
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
use_auth_token: HuggingFace authentication token or True to use the
token stored by the HuggingFace config folder.
"""
self.logger = get_logger()

@@ -629,6 +668,8 @@ class WhisperModel:
model_size_or_path,
local_files_only=local_files_only,
cache_dir=download_root,
revision=revision,
use_auth_token=use_auth_token,
)
self.model = ctranslate2.models.Whisper(
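With the new constructor arguments, a pinned or gated Hugging Face revision can be loaded directly through WhisperModel. A minimal sketch under stated assumptions (the repository id and revision are placeholders, not taken from the diff):

    from faster_whisper import WhisperModel

    model = WhisperModel(
        "your-org/your-private-faster-whisper",  # placeholder repo id
        device="cpu",
        compute_type="int8",
        revision="main",       # branch name, tag, or commit hash
        use_auth_token=True,   # or pass a token string explicitly
    )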
@@ -1737,7 +1778,7 @@ class WhisperModel:
Returns:
language: Detected language.
languege_probability: Probability of the detected language.
language_probability: Probability of the detected language.
all_language_probs: List of tuples with all language names and probabilities.
"""
assert (

@@ -1810,7 +1851,7 @@ def restore_speech_timestamps(
else:
segment.start = ts_map.get_original_time(segment.start)
segment.end = ts_map.get_original_time(segment.end)
segment.end = ts_map.get_original_time(segment.end, is_end=True)
yield segment

@@ -1845,6 +1886,7 @@ def get_suppressed_tokens(
tokenizer.sot,
tokenizer.sot_prev,
tokenizer.sot_lm,
tokenizer.no_speech,
]
)
faster_whisper/utils.py
@@ -2,7 +2,7 @@ import logging
import os
import re
from typing import List, Optional
from typing import List, Optional, Union
import huggingface_hub
import requests

@@ -26,6 +26,7 @@ _MODELS = {
"distil-medium.en": "Systran/faster-distil-whisper-medium.en",
"distil-small.en": "Systran/faster-distil-whisper-small.en",
"distil-large-v3": "Systran/faster-distil-whisper-large-v3",
"distil-large-v3.5": "distil-whisper/distil-large-v3.5-ct2",
"large-v3-turbo": "mobiuslabsgmbh/faster-whisper-large-v3-turbo",
"turbo": "mobiuslabsgmbh/faster-whisper-large-v3-turbo",
}
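The new "distil-large-v3.5" alias maps to the converted distil-whisper/distil-large-v3.5-ct2 repository, so it can be requested by name like the existing sizes. A minimal sketch (device, compute type, and audio path are illustrative choices, not taken from the diff):

    from faster_whisper import WhisperModel

    model = WhisperModel("distil-large-v3.5", device="cpu", compute_type="int8")
    segments, info = model.transcribe("audio.wav")  # hypothetical audio file

    print(info.language, info.language_probability)
    for segment in segments:
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")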
@@ -51,6 +52,8 @@ def download_model(
output_dir: Optional[str] = None,
local_files_only: bool = False,
cache_dir: Optional[str] = None,
revision: Optional[str] = None,
use_auth_token: Optional[Union[str, bool]] = None,
):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.

@@ -65,6 +68,10 @@ def download_model(
local_files_only: If True, avoid downloading the file and return the path to the local
cached file if it exists.
cache_dir: Path to the folder where cached files are stored.
revision: An optional Git revision id which can be a branch name, a tag, or a
commit hash.
use_auth_token: HuggingFace authentication token or True to use the
token stored by the HuggingFace config folder.
Returns:
The path to the downloaded model.

@@ -94,6 +101,7 @@ def download_model(
"local_files_only": local_files_only,
"allow_patterns": allow_patterns,
"tqdm_class": disabled_tqdm,
"revision": revision,
}
if output_dir is not None:

@@ -103,6 +111,9 @@ def download_model(
if cache_dir is not None:
kwargs["cache_dir"] = cache_dir
if use_auth_token is not None:
kwargs["token"] = use_auth_token
try:
return huggingface_hub.snapshot_download(repo_id, **kwargs)
except (
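Taken together, these hunks thread revision and use_auth_token through to huggingface_hub.snapshot_download. A minimal sketch of calling the helper directly (the revision value is a placeholder):

    from faster_whisper.utils import download_model

    # Pin the download to a revision; pass a token (or True) for gated/private repos.
    model_dir = download_model(
        "distil-large-v3.5",
        revision="main",      # placeholder: branch, tag, or commit hash
        use_auth_token=True,  # reuse the token stored by the Hugging Face CLI
    )
    print(model_dir)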
faster_whisper/vad.py
@@ -16,14 +16,14 @@ class VadOptions:
"""VAD options.
Attributes:
onset: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
offset: Silence threshold for determining the end of speech. If a probability is lower than
the offset, it is always considered silence. Values higher than offset are only considered
speech if the previous sample was classified as speech; otherwise, they are treated as
silence. This parameter helps refine the detection of speech transitions, ensuring smoother
segment boundaries.
neg_threshold: Silence threshold for determining the end of speech. If a probability is lower
than neg_threshold, it is always considered silence. Values higher than neg_threshold
are only considered speech if the previous sample was classified as speech; otherwise,
they are treated as silence. This parameter helps refine the detection of speech
transitions, ensuring smoother segment boundaries.
min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
than max_speech_duration_s will be split at the timestamp of the last silence that

@@ -34,8 +34,8 @@ class VadOptions:
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
"""
onset: float = 0.5
offset: float = onset - 0.15
threshold: float = 0.5
neg_threshold: float = None
min_speech_duration_ms: int = 0
max_speech_duration_s: float = float("inf")
min_silence_duration_ms: int = 2000
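Since the fields were renamed from onset/offset to threshold/neg_threshold, VAD tuning now looks like the sketch below. The numeric values are only illustrative; leaving neg_threshold as None lets it default to threshold - 0.15 (floored at 0.01), as in the code further down:

    from faster_whisper import WhisperModel
    from faster_whisper.vad import VadOptions

    model = WhisperModel("tiny")  # small checkpoint chosen only for the example

    vad_options = VadOptions(
        threshold=0.5,       # probabilities above this are treated as speech
        neg_threshold=0.35,  # probabilities below this end a speech run
        min_silence_duration_ms=2000,
    )

    segments, info = model.transcribe(
        "audio.wav",                 # hypothetical input file
        vad_filter=True,
        vad_parameters=vad_options,  # a dict with the same keys also works
    )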
@@ -62,7 +62,8 @@ def get_speech_timestamps(
if vad_options is None:
vad_options = VadOptions(**kwargs)
onset = vad_options.onset
threshold = vad_options.threshold
neg_threshold = vad_options.neg_threshold
min_speech_duration_ms = vad_options.min_speech_duration_ms
max_speech_duration_s = vad_options.max_speech_duration_s
min_silence_duration_ms = vad_options.min_silence_duration_ms

@@ -85,12 +86,13 @@ def get_speech_timestamps(
padded_audio = np.pad(
audio, (0, window_size_samples - audio.shape[0] % window_size_samples)
)
speech_probs = model(padded_audio.reshape(1, -1)).squeeze(0)
speech_probs = model(padded_audio)
triggered = False
speeches = []
current_speech = {}
offset = vad_options.offset
if neg_threshold is None:
neg_threshold = max(threshold - 0.15, 0.01)
# to save potential segment end (and tolerate some silence)
temp_end = 0

@@ -98,12 +100,12 @@ def get_speech_timestamps(
prev_end = next_start = 0
for i, speech_prob in enumerate(speech_probs):
if (speech_prob >= onset) and temp_end:
if (speech_prob >= threshold) and temp_end:
temp_end = 0
if next_start < prev_end:
next_start = window_size_samples * i
if (speech_prob >= onset) and not triggered:
if (speech_prob >= threshold) and not triggered:
triggered = True
current_speech["start"] = window_size_samples * i
continue

@@ -130,7 +132,7 @@ def get_speech_timestamps(
triggered = False
continue
if (speech_prob < offset) and triggered:
if (speech_prob < neg_threshold) and triggered:
if not temp_end:
temp_end = window_size_samples * i
# condition to avoid cutting in very short silence
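A minimal sketch of running the standalone VAD with the renamed options (the audio path is illustrative); it returns speech chunks as sample indices, which can be fed to collect_chunks below:

    from faster_whisper import decode_audio
    from faster_whisper.vad import VadOptions, get_speech_timestamps

    audio = decode_audio("audio.wav", sampling_rate=16000)  # hypothetical file

    # Returns a list of {"start": ..., "end": ...} dicts in samples (16 kHz).
    speech_chunks = get_speech_timestamps(audio, VadOptions(threshold=0.5))
    print(speech_chunks[:3])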
@@ -182,25 +184,62 @@
def collect_chunks(
audio: np.ndarray, chunks: List[dict], sampling_rate: int = 16000
) -> Tuple[List[np.ndarray], List[Dict[str, int]]]:
"""Collects audio chunks."""
audio: np.ndarray,
chunks: List[dict],
sampling_rate: int = 16000,
max_duration: float = float("inf"),
) -> Tuple[List[np.ndarray], List[Dict[str, float]]]:
"""This function merges the chunks of audio into chunks of max_duration (s) length."""
if not chunks:
chunk_metadata = {
"start_time": 0,
"end_time": 0,
"offset": 0,
"duration": 0,
"segments": [],
}
return [np.array([], dtype=np.float32)], [chunk_metadata]
audio_chunks = []
chunks_metadata = []
current_segments = []
current_duration = 0
total_duration = 0
current_audio = np.array([], dtype=np.float32)
for chunk in chunks:
chunk_metadata = {
"start_time": chunk["start"] / sampling_rate,
"end_time": chunk["end"] / sampling_rate,
}
audio_chunks.append(audio[chunk["start"] : chunk["end"]])
chunks_metadata.append(chunk_metadata)
if (
current_duration + chunk["end"] - chunk["start"]
> max_duration * sampling_rate
):
audio_chunks.append(current_audio)
chunk_metadata = {
"offset": total_duration / sampling_rate,
"duration": current_duration / sampling_rate,
"segments": current_segments,
}
total_duration += current_duration
chunks_metadata.append(chunk_metadata)
current_segments = []
current_audio = audio[chunk["start"] : chunk["end"]]
current_duration = chunk["end"] - chunk["start"]
else:
current_segments.append(chunk)
current_audio = np.concatenate(
(current_audio, audio[chunk["start"] : chunk["end"]])
)
current_duration += chunk["end"] - chunk["start"]
audio_chunks.append(current_audio)
chunk_metadata = {
"offset": total_duration / sampling_rate,
"duration": current_duration / sampling_rate,
"segments": current_segments,
}
chunks_metadata.append(chunk_metadata)
return audio_chunks, chunks_metadata
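collect_chunks now merges detected speech chunks into batches of at most max_duration seconds and reports an offset/duration/segments metadata dict per merged chunk. A minimal sketch continuing the previous example (the 30-second cap mirrors the default Whisper chunk length):

    from faster_whisper.vad import collect_chunks

    audio_chunks, chunks_metadata = collect_chunks(
        audio, speech_chunks, sampling_rate=16000, max_duration=30.0
    )

    for chunk, metadata in zip(audio_chunks, chunks_metadata):
        print(
            f"offset={metadata['offset']:.2f}s "
            f"duration={metadata['duration']:.2f}s "
            f"samples={chunk.shape[0]} "
            f"merged_segments={len(metadata['segments'])}"
        )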
@@ -227,15 +266,19 @@ class SpeechTimestampsMap:
self,
time: float,
chunk_index: Optional[int] = None,
is_end: bool = False,
) -> float:
if chunk_index is None:
chunk_index = self.get_chunk_index(time)
chunk_index = self.get_chunk_index(time, is_end)
total_silence_before = self.total_silence_before[chunk_index]
return round(total_silence_before + time, self.time_precision)
def get_chunk_index(self, time: float) -> int:
def get_chunk_index(self, time: float, is_end: bool = False) -> int:
sample = int(time * self.sampling_rate)
if sample in self.chunk_end_sample and is_end:
return self.chunk_end_sample.index(sample)
return min(
bisect.bisect(self.chunk_end_sample, sample),
len(self.chunk_end_sample) - 1,
@@ -245,13 +288,12 @@
@functools.lru_cache
def get_vad_model():
"""Returns the VAD model instance."""
encoder_path = os.path.join(get_assets_path(), "silero_encoder_v5.onnx")
decoder_path = os.path.join(get_assets_path(), "silero_decoder_v5.onnx")
return SileroVADModel(encoder_path, decoder_path)
path = os.path.join(get_assets_path(), "silero_vad_v6.onnx")
return SileroVADModel(path)
class SileroVADModel:
def __init__(self, encoder_path, decoder_path):
def __init__(self, path):
try:
import onnxruntime
except ImportError as e:

@@ -260,17 +302,13 @@
) from e
opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 0
opts.intra_op_num_threads = 0
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
opts.enable_cpu_mem_arena = False
opts.log_severity_level = 4
self.encoder_session = onnxruntime.InferenceSession(
encoder_path,
providers=["CPUExecutionProvider"],
sess_options=opts,
)
self.decoder_session = onnxruntime.InferenceSession(
decoder_path,
self.session = onnxruntime.InferenceSession(
path,
providers=["CPUExecutionProvider"],
sess_options=opts,
)
@@ -278,83 +316,36 @@
def __call__(
self, audio: np.ndarray, num_samples: int = 512, context_size_samples: int = 64
):
assert audio.ndim == 1, "Input should be a 1D array"
assert (
audio.ndim == 2
), "Input should be a 2D array with size (batch_size, num_samples)"
assert (
audio.shape[1] % num_samples == 0
audio.shape[0] % num_samples == 0
), "Input size should be a multiple of num_samples"
batch_size = audio.shape[0]
state = np.zeros((2, batch_size, 128), dtype="float32")
h = np.zeros((1, 1, 128), dtype="float32")
c = np.zeros((1, 1, 128), dtype="float32")
context = np.zeros(
(batch_size, context_size_samples),
(1, context_size_samples),
dtype="float32",
)
batched_audio = audio.reshape(batch_size, -1, num_samples)
batched_audio = audio.reshape(-1, num_samples)
context = batched_audio[..., -context_size_samples:]
context[:, -1] = 0
context = np.roll(context, 1, 1)
batched_audio = np.concatenate([context, batched_audio], 2)
context[-1] = 0
context = np.roll(context, 1, 0)
batched_audio = np.concatenate([context, batched_audio], 1)
batched_audio = batched_audio.reshape(-1, num_samples + context_size_samples)
encoder_output = self.encoder_session.run(None, {"input": batched_audio})[0]
encoder_output = encoder_output.reshape(batch_size, -1, 128)
decoder_outputs = []
for window in np.split(encoder_output, encoder_output.shape[1], axis=1):
out, state = self.decoder_session.run(
None, {"input": window.squeeze(1), "state": state}
encoder_batch_size = 10000
num_segments = batched_audio.shape[0]
outputs = []
for i in range(0, num_segments, encoder_batch_size):
output, h, c = self.session.run(
None,
{"input": batched_audio[i : i + encoder_batch_size], "h": h, "c": c},
)
decoder_outputs.append(out)
outputs.append(output)
out = np.concatenate(outputs, axis=0)
out = np.stack(decoder_outputs, axis=1).squeeze(-1)
return out

def merge_segments(segments_list, vad_options: VadOptions, sampling_rate: int = 16000):
if not segments_list:
return []
curr_end = 0
seg_idxs = []
merged_segments = []
edge_padding = vad_options.speech_pad_ms * sampling_rate // 1000
chunk_length = vad_options.max_speech_duration_s * sampling_rate
curr_start = segments_list[0]["start"]
for idx, seg in enumerate(segments_list):
# if any segment start timing is less than previous segment end timing,
# reset the edge padding. Similarly for end timing.
if idx > 0:
if seg["start"] < segments_list[idx - 1]["end"]:
seg["start"] += edge_padding
if idx < len(segments_list) - 1:
if seg["end"] > segments_list[idx + 1]["start"]:
seg["end"] -= edge_padding
if seg["end"] - curr_start > chunk_length and curr_end - curr_start > 0:
merged_segments.append(
{
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
}
)
curr_start = seg["start"]
seg_idxs = []
curr_end = seg["end"]
seg_idxs.append((seg["start"], seg["end"]))
# add final
merged_segments.append(
{
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
}
)
return merged_segments
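With the single silero_vad_v6.onnx session, the model object is called on a 1D float32 array whose length is a multiple of the 512-sample window, exactly as get_speech_timestamps prepares it. A minimal sketch of exercising it directly (synthetic silence, for illustration only; output shape depends on the ONNX graph):

    import numpy as np

    from faster_whisper.vad import get_vad_model

    model = get_vad_model()  # cached SileroVADModel backed by silero_vad_v6.onnx

    # One second of silence at 16 kHz, padded to a multiple of 512 samples.
    audio = np.zeros(16000, dtype=np.float32)
    padded = np.pad(audio, (0, 512 - audio.shape[0] % 512))

    speech_probs = model(padded)  # speech probability for each 512-sample window
    print(speech_probs.shape, speech_probs.max())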
faster_whisper/version.py
@@ -1,3 +1,3 @@
"""Version information."""
__version__ = "1.1.0"
__version__ = "1.2.0"
setup.py (5 changes)
@@ -45,14 +45,13 @@ setup(
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
keywords="openai whisper speech ctranslate2 inference quantization transformer",
python_requires=">=3.8",
python_requires=">=3.10",
install_requires=install_requires,
extras_require={
"conversion": conversion_requires,
@@ -98,6 +98,7 @@ def test_suppressed_tokens_minus_1():
50358,
50359,
50360,
50361,
)

@@ -106,7 +107,7 @@ def test_suppressed_tokens_minus_value():
tokenizer = Tokenizer(model.hf_tokenizer, False)
tokens = get_suppressed_tokens(tokenizer, [13])
assert tokens == (13, 50257, 50357, 50358, 50359, 50360)
assert tokens == (13, 50257, 50357, 50358, 50359, 50360, 50361)
def test_split_on_unicode():
@@ -71,7 +71,7 @@ def test_batched_transcribe(physcisworks_path):
{"start": segment.start, "end": segment.end, "text": segment.text}
)
# number of near 30 sec segments
assert len(segments) == 7
assert len(segments) == 6
result, info = batched_model.transcribe(
physcisworks_path,

@@ -269,3 +269,24 @@ def test_monotonic_timestamps(physcisworks_path):
assert word.start <= word.end
assert word.end <= segments[i].end
assert segments[-1].end <= info.duration

def test_cliptimestamps_segments(jfk_path):
model = WhisperModel("tiny")
pipeline = BatchedInferencePipeline(model=model)
audio = decode_audio(jfk_path)
audio = np.concatenate([audio, audio])
clip_timestamps = [{"start": 0.0, "end": 11.0}, {"start": 11.0, "end": 22.0}]
segments, info = pipeline.transcribe(audio, clip_timestamps=clip_timestamps)
segments = list(segments)
assert len(segments) == 2
for segment, clip in zip(segments, clip_timestamps):
assert segment.start == clip["start"]
assert segment.end == clip["end"]
assert segment.text == (
" And so my fellow Americans ask not what your country can do for you, "
"ask what you can do for your country."
)