Mirror of https://github.com/SYSTRAN/faster-whisper.git (synced 2026-01-13 07:27:55 -05:00)

Compare commits

11 Commits
| SHA1 |
|---|
| ed9a06cd89 |
| 2eeafe05de |
| cf42429f96 |
| 65882eee9f |
| 409a6919f9 |
| 00a5b26b1f |
| 9090997d25 |
| dea24cbcc6 |
| 14ba1051f3 |
| c26d609974 |
| 4bd98d5c5b |
@@ -1,4 +1,3 @@
include faster_whisper/assets/silero_encoder_v5.onnx
include faster_whisper/assets/silero_decoder_v5.onnx
include faster_whisper/assets/silero_vad_v6.onnx
include requirements.txt
include requirements.conversion.txt
@@ -250,6 +250,7 @@ Here is a non exhaustive list of open-source projects using faster-whisper. Feel
* [WhisperLive](https://github.com/collabora/WhisperLive) is a nearly-live implementation of OpenAI's Whisper which uses faster-whisper as the backend to transcribe audio in real-time.
* [Faster-Whisper-Transcriber](https://github.com/BBC-Esq/ctranslate2-faster-whisper-transcriber) is a simple but reliable voice transcriber that provides a user-friendly interface.
* [Open-dubbing](https://github.com/softcatala/open-dubbing) is an AI dubbing system which uses machine learning models to automatically translate and synchronize audio dialogue into different languages.
* [Whisper-FastAPI](https://github.com/heimoshuiyu/whisper-fastapi) is a very simple script that provides an API backend compatible with OpenAI, HomeAssistant, and Konele (Android voice typing) formats.

## Model conversion
Binary file not shown.
Binary file not shown.
BIN faster_whisper/assets/silero_vad_v6.onnx (new file, binary not shown)
@@ -67,6 +67,12 @@ class Tokenizer:
    def no_timestamps(self) -> int:
        return self.tokenizer.token_to_id("<|notimestamps|>")

    @cached_property
    def no_speech(self) -> int:
        return self.tokenizer.token_to_id("<|nospeech|>") or self.tokenizer.token_to_id(
            "<|nocaptions|>"
        )

    @property
    def timestamp_begin(self) -> int:
        return self.no_timestamps + 1
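For orientation, a minimal sketch of how the new `no_speech` property and the matching change to `get_suppressed_tokens` (further down in this diff) fit together: `token_to_id` returns `None` for an unknown special token, so the lookup falls back from `<|nospeech|>` to the older `<|nocaptions|>` spelling, and the resulting id is now part of the default suppression list. The model choice and import locations below mirror the repository's tests and current layout; treat them as illustrative rather than definitive.

```python
from faster_whisper import WhisperModel
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import get_suppressed_tokens

# Illustrative sketch: the model and the Tokenizer call pattern follow the tests in this diff.
model = WhisperModel("tiny")
tokenizer = Tokenizer(model.hf_tokenizer, False)

# Resolves <|nospeech|>, falling back to <|nocaptions|> when the first lookup returns None.
print(tokenizer.no_speech)

# With this diff, the no-speech token is included in the default suppression list (-1).
tokens = get_suppressed_tokens(tokenizer, [-1])
assert tokenizer.no_speech in tokens
```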
@@ -417,15 +417,38 @@ class BatchedInferencePipeline:
                "No clip timestamps found. "
                "Set 'vad_filter' to True or provide 'clip_timestamps'."
            )

            clip_timestamps_provided = False
            audio_chunks, chunks_metadata = collect_chunks(
                audio, clip_timestamps, max_duration=chunk_length
            )

        else:
            clip_timestamps_provided = True
            clip_timestamps = [
                {k: int(v * sampling_rate) for k, v in segment.items()}
                for segment in clip_timestamps
            ]

            audio_chunks, chunks_metadata = collect_chunks(
                audio, clip_timestamps, max_duration=chunk_length
            )
            audio_chunks, chunks_metadata = [], []
            for i, clip in enumerate(clip_timestamps):
                audio_chunks.append(audio[clip["start"] : clip["end"]])

                clip_duration = (clip["end"] - clip["start"]) / sampling_rate
                if clip_duration > 30:
                    self.model.logger.warning(
                        "Segment %d is longer than 30 seconds, "
                        "only the first 30 seconds will be transcribed",
                        i,
                    )

                chunks_metadata.append(
                    {
                        "offset": clip["start"] / sampling_rate,
                        "duration": clip_duration,
                        "segments": [clip],
                    }
                )

        duration_after_vad = (
            sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
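Taken together, this hunk means that user-supplied `clip_timestamps` (in seconds) are no longer routed through `collect_chunks`: each clip is converted to sample indices, sliced directly from the audio (with a warning and truncation past 30 seconds), and later reported with its original offset. A minimal usage sketch, modeled on the new tests at the end of this diff; the audio path and clip values are placeholders.

```python
from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio

model = WhisperModel("tiny")
pipeline = BatchedInferencePipeline(model=model)

audio = decode_audio("audio.wav")  # placeholder path

# One chunk is transcribed per clip; clips longer than 30 s are truncated with a warning.
clip_timestamps = [{"start": 0.0, "end": 11.0}, {"start": 11.0, "end": 22.0}]
segments, info = pipeline.transcribe(audio, clip_timestamps=clip_timestamps)

for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```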
@@ -547,7 +570,10 @@ class BatchedInferencePipeline:
            options,
            log_progress,
        )
        segments = restore_speech_timestamps(segments, clip_timestamps, sampling_rate)
        if not clip_timestamps_provided:
            segments = restore_speech_timestamps(
                segments, clip_timestamps, sampling_rate
            )

        return segments, info
@@ -1766,7 +1792,7 @@ class WhisperModel:

        Returns:
            language: Detected language.
            languege_probability: Probability of the detected language.
            language_probability: Probability of the detected language.
            all_language_probs: List of tuples with all language names and probabilities.
        """
        assert (
@@ -1874,6 +1900,7 @@ def get_suppressed_tokens(
            tokenizer.sot,
            tokenizer.sot_prev,
            tokenizer.sot_lm,
            tokenizer.no_speech,
        ]
    )
@@ -5,7 +5,6 @@ import re
from typing import List, Optional, Union

import huggingface_hub
import requests

from tqdm.auto import tqdm

@@ -106,7 +105,6 @@ def download_model(

    if output_dir is not None:
        kwargs["local_dir"] = output_dir
        kwargs["local_dir_use_symlinks"] = False

    if cache_dir is not None:
        kwargs["cache_dir"] = cache_dir

@@ -114,24 +112,7 @@ def download_model(
    if use_auth_token is not None:
        kwargs["token"] = use_auth_token

    try:
        return huggingface_hub.snapshot_download(repo_id, **kwargs)
    except (
        huggingface_hub.utils.HfHubHTTPError,
        requests.exceptions.ConnectionError,
    ) as exception:
        logger = get_logger()
        logger.warning(
            "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
            repo_id,
            exception,
        )
        logger.warning(
            "Trying to load the model directly from the local cache, if it exists."
        )

        kwargs["local_files_only"] = True
        return huggingface_hub.snapshot_download(repo_id, **kwargs)
    return huggingface_hub.snapshot_download(repo_id, **kwargs)


def format_timestamp(
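The removed block above means `download_model` no longer falls back to the local Hugging Face cache when the Hub is unreachable: `snapshot_download` is now called exactly once and any `HfHubHTTPError` or `ConnectionError` propagates to the caller. If offline fallback is still wanted, a caller can reproduce it roughly as sketched below; the repo id is illustrative, and `local_files_only` is assumed to still be a parameter of `download_model`.

```python
import huggingface_hub
import requests

from faster_whisper.utils import download_model

repo_id = "Systran/faster-whisper-tiny"  # illustrative model id

try:
    model_dir = download_model(repo_id, output_dir="./models")
except (huggingface_hub.utils.HfHubHTTPError, requests.exceptions.ConnectionError):
    # Mirror the removed behaviour: fall back to whatever is already cached locally.
    model_dir = download_model(repo_id, local_files_only=True)

print(model_dir)
```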
@@ -27,11 +27,15 @@ class VadOptions:
        min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
        max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
            than max_speech_duration_s will be split at the timestamp of the last silence that
            lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
            split aggressively just before max_speech_duration_s.
            lasts more than min_silence_at_max_speech (if any), to prevent aggressive cutting.
            Otherwise, they will be split aggressively just before max_speech_duration_s.
        min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
            before separating it.
        speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side.
        min_silence_at_max_speech: Minimum silence duration in ms which is used to avoid abrupt cuts
            when max_speech_duration_s is reached.
        use_max_poss_sil_at_max_speech: Whether to use the maximum possible silence at
            max_speech_duration_s or not. If not, the last silence is used.
    """

    threshold: float = 0.5

@@ -40,6 +44,8 @@ class VadOptions:
    max_speech_duration_s: float = float("inf")
    min_silence_duration_ms: int = 2000
    speech_pad_ms: int = 400
    min_silence_at_max_speech: int = 98
    use_max_poss_sil_at_max_speech: bool = True


def get_speech_timestamps(
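For reference, a small sketch of how the two new options would be set from user code. Passing a `VadOptions` instance through `vad_parameters` follows the existing plumbing in `WhisperModel.transcribe`; the file path and values below are illustrative, and the defaults shown are the ones introduced by this diff.

```python
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions

# New fields from this diff; defaults are min_silence_at_max_speech=98 and
# use_max_poss_sil_at_max_speech=True.
vad_options = VadOptions(
    max_speech_duration_s=30.0,
    min_silence_at_max_speech=98,         # silence (ms) that counts as a possible split point
    use_max_poss_sil_at_max_speech=True,  # split at the longest such silence, not the last one
)

model = WhisperModel("tiny")
segments, info = model.transcribe(
    "audio.wav",  # illustrative path
    vad_filter=True,
    vad_parameters=vad_options,
)
```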
@@ -69,6 +75,9 @@ def get_speech_timestamps(
    min_silence_duration_ms = vad_options.min_silence_duration_ms
    window_size_samples = 512
    speech_pad_ms = vad_options.speech_pad_ms
    min_silence_at_max_speech = vad_options.min_silence_at_max_speech
    use_max_poss_sil_at_max_speech = vad_options.use_max_poss_sil_at_max_speech

    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = (

@@ -77,7 +86,7 @@ def get_speech_timestamps(
        - 2 * speech_pad_samples
    )
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
    min_silence_samples_at_max_speech = sampling_rate * min_silence_at_max_speech / 1000

    audio_length_samples = len(audio)

@@ -86,11 +95,13 @@ def get_speech_timestamps(
    padded_audio = np.pad(
        audio, (0, window_size_samples - audio.shape[0] % window_size_samples)
    )
    speech_probs = model(padded_audio.reshape(1, -1)).squeeze(0)
    speech_probs = model(padded_audio)

    triggered = False
    speeches = []
    current_speech = {}
    possible_ends = []

    if neg_threshold is None:
        neg_threshold = max(threshold - 0.15, 0.01)

@@ -100,45 +111,67 @@ def get_speech_timestamps(
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
        cur_sample = window_size_samples * i

        if (speech_prob >= threshold) and temp_end:
            sil_dur = cur_sample - temp_end
            if sil_dur > min_silence_samples_at_max_speech:
                possible_ends.append((temp_end, sil_dur))
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i
                next_start = cur_sample

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech["start"] = window_size_samples * i
            current_speech["start"] = cur_sample
            continue

        if (
            triggered
            and (window_size_samples * i) - current_speech["start"] > max_speech_samples
        ):
            if prev_end:
        if triggered and (cur_sample - current_speech["start"] > max_speech_samples):
            if use_max_poss_sil_at_max_speech and possible_ends:
                prev_end, dur = max(possible_ends, key=lambda x: x[1])
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                # previously reached silence (< neg_thres) and is still not speech (< thres)
                if next_start < prev_end:
                    triggered = False
                else:
                    next_start = prev_end + dur

                if next_start < prev_end + cur_sample:
                    current_speech["start"] = next_start
                else:
                    triggered = False
                prev_end = next_start = temp_end = 0
                possible_ends = []
            else:
                current_speech["end"] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue
            if prev_end:
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
                possible_ends = []
            else:
                current_speech["end"] = cur_sample
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                possible_ends = []
            continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = window_size_samples * i
            # condition to avoid cutting in very short silence
            if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
                temp_end = cur_sample
            sil_dur_now = cur_sample - temp_end

            if (
                not use_max_poss_sil_at_max_speech
                and sil_dur_now > min_silence_samples_at_max_speech
            ):
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:

            if sil_dur_now < min_silence_samples:
                continue
            else:
                current_speech["end"] = temp_end

@@ -149,6 +182,7 @@ def get_speech_timestamps(
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                possible_ends = []
                continue

    if (
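To make the new splitting rule concrete: every silence longer than `min_silence_at_max_speech` (98 ms by default, i.e. 16000 * 98 / 1000 = 1568 samples at 16 kHz) is recorded in `possible_ends` as a `(silence_start, silence_duration)` pair, and when a speech run exceeds `max_speech_duration_s` the chunk is cut at the longest recorded silence. With `use_max_poss_sil_at_max_speech=False` the cut falls back to the last qualifying silence, close to the old behaviour. A small self-contained sketch of just that selection step, with made-up sample positions:

```python
# Hypothetical values, in samples at 16 kHz: (silence_start, silence_duration)
possible_ends = [(120_000, 1_800), (240_000, 4_200), (300_000, 2_000)]

use_max_poss_sil_at_max_speech = True

if use_max_poss_sil_at_max_speech and possible_ends:
    # New default: cut at the longest silence seen so far (240_000 here).
    prev_end, dur = max(possible_ends, key=lambda x: x[1])
else:
    # Fallback: cut at the most recent silence that exceeded the threshold.
    prev_end, dur = possible_ends[-1]

print(prev_end, dur)  # 240000 4200
```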
@@ -288,13 +322,12 @@ class SpeechTimestampsMap:
@functools.lru_cache
def get_vad_model():
    """Returns the VAD model instance."""
    encoder_path = os.path.join(get_assets_path(), "silero_encoder_v5.onnx")
    decoder_path = os.path.join(get_assets_path(), "silero_decoder_v5.onnx")
    return SileroVADModel(encoder_path, decoder_path)
    path = os.path.join(get_assets_path(), "silero_vad_v6.onnx")
    return SileroVADModel(path)


class SileroVADModel:
    def __init__(self, encoder_path, decoder_path):
    def __init__(self, path):
        try:
            import onnxruntime
        except ImportError as e:

@@ -308,13 +341,8 @@ class SileroVADModel:
        opts.enable_cpu_mem_arena = False
        opts.log_severity_level = 4

        self.encoder_session = onnxruntime.InferenceSession(
            encoder_path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )
        self.decoder_session = onnxruntime.InferenceSession(
            decoder_path,
        self.session = onnxruntime.InferenceSession(
            path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )

@@ -322,47 +350,36 @@ class SileroVADModel:
    def __call__(
        self, audio: np.ndarray, num_samples: int = 512, context_size_samples: int = 64
    ):
        assert audio.ndim == 1, "Input should be a 1D array"
        assert (
            audio.ndim == 2
        ), "Input should be a 2D array with size (batch_size, num_samples)"
        assert (
            audio.shape[1] % num_samples == 0
            audio.shape[0] % num_samples == 0
        ), "Input size should be a multiple of num_samples"

        batch_size = audio.shape[0]

        state = np.zeros((2, batch_size, 128), dtype="float32")
        h = np.zeros((1, 1, 128), dtype="float32")
        c = np.zeros((1, 1, 128), dtype="float32")
        context = np.zeros(
            (batch_size, context_size_samples),
            (1, context_size_samples),
            dtype="float32",
        )

        batched_audio = audio.reshape(batch_size, -1, num_samples)
        batched_audio = audio.reshape(-1, num_samples)
        context = batched_audio[..., -context_size_samples:]
        context[:, -1] = 0
        context = np.roll(context, 1, 1)
        batched_audio = np.concatenate([context, batched_audio], 2)
        context[-1] = 0
        context = np.roll(context, 1, 0)
        batched_audio = np.concatenate([context, batched_audio], 1)

        batched_audio = batched_audio.reshape(-1, num_samples + context_size_samples)

        encoder_batch_size = 10000
        num_segments = batched_audio.shape[0]
        encoder_outputs = []
        outputs = []
        for i in range(0, num_segments, encoder_batch_size):
            encoder_output = self.encoder_session.run(
                None, {"input": batched_audio[i : i + encoder_batch_size]}
            )[0]
            encoder_outputs.append(encoder_output)

        encoder_output = np.concatenate(encoder_outputs, axis=0)
        encoder_output = encoder_output.reshape(batch_size, -1, 128)

        decoder_outputs = []
        for window in np.split(encoder_output, encoder_output.shape[1], axis=1):
            out, state = self.decoder_session.run(
                None, {"input": window.squeeze(1), "state": state}
            output, h, c = self.session.run(
                None,
                {"input": batched_audio[i : i + encoder_batch_size], "h": h, "c": c},
            )
            decoder_outputs.append(out)
            outputs.append(output)

        out = np.concatenate(outputs, axis=0)

        out = np.stack(decoder_outputs, axis=1).squeeze(-1)
        return out
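The practical difference here is that the Silero v6 asset is a single ONNX graph with explicit `h`/`c` LSTM state inputs, replacing the separate encoder and decoder sessions, and `__call__` now expects a flat 1D waveform whose length is a multiple of 512 samples (as `get_speech_timestamps` prepares it). A hedged sketch of driving the model directly; the audio is synthetic and real callers should keep going through `get_speech_timestamps`.

```python
import numpy as np

from faster_whisper.vad import get_vad_model

model = get_vad_model()  # cached SileroVADModel wrapping silero_vad_v6.onnx

window_size_samples = 512
audio = np.random.rand(16000).astype(np.float32)  # placeholder: 1 s of noise at 16 kHz

# Pad to a multiple of the 512-sample window, mirroring what get_speech_timestamps does.
pad = (-audio.shape[0]) % window_size_samples
padded_audio = np.pad(audio, (0, pad))

speech_probs = model(padded_audio)  # one speech probability per 512-sample window
print(speech_probs.shape)
```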
@@ -1,3 +1,3 @@
"""Version information."""

__version__ = "1.2.0"
__version__ = "1.2.1"
@@ -1,6 +1,6 @@
ctranslate2>=4.0,<5
huggingface_hub>=0.13
huggingface_hub>=0.23
tokenizers>=0.13,<1
onnxruntime>=1.14,<2
av>=11
tqdm
tqdm
@@ -98,6 +98,7 @@ def test_suppressed_tokens_minus_1():
        50358,
        50359,
        50360,
        50361,
    )


@@ -106,7 +107,7 @@ def test_suppressed_tokens_minus_value():

    tokenizer = Tokenizer(model.hf_tokenizer, False)
    tokens = get_suppressed_tokens(tokenizer, [13])
    assert tokens == (13, 50257, 50357, 50358, 50359, 50360)
    assert tokens == (13, 50257, 50357, 50358, 50359, 50360, 50361)


def test_split_on_unicode():

@@ -245,7 +245,7 @@ def test_transcribe_signature():


def test_monotonic_timestamps(physcisworks_path):
    model = WhisperModel("tiny")
    model = WhisperModel("base")
    pipeline = BatchedInferencePipeline(model=model)

    segments, info = model.transcribe(physcisworks_path, word_timestamps=True)

@@ -269,3 +269,47 @@ def test_monotonic_timestamps(physcisworks_path):
            assert word.start <= word.end
            assert word.end <= segments[i].end
    assert segments[-1].end <= info.duration


def test_cliptimestamps_segments(jfk_path):
    model = WhisperModel("tiny")
    pipeline = BatchedInferencePipeline(model=model)

    audio = decode_audio(jfk_path)
    audio = np.concatenate([audio, audio])
    clip_timestamps = [{"start": 0.0, "end": 11.0}, {"start": 11.0, "end": 22.0}]

    segments, info = pipeline.transcribe(audio, clip_timestamps=clip_timestamps)
    segments = list(segments)

    assert len(segments) == 2
    for segment, clip in zip(segments, clip_timestamps):
        assert segment.start == clip["start"]
        assert segment.end == clip["end"]
        assert segment.text == (
            " And so my fellow Americans ask not what your country can do for you, "
            "ask what you can do for your country."
        )


def test_cliptimestamps_timings(physcisworks_path):
    model = WhisperModel("tiny")
    pipeline = BatchedInferencePipeline(model=model)

    audio = decode_audio(physcisworks_path)
    clip_timestamps = [{"start": 0.0, "end": 5.0}, {"start": 6.0, "end": 15.0}]
    transcripts = [
        " Now I want to return to the conservation of mechanical energy.",
        (
            " I have here a pendulum. I have an object that weighs 15 kilograms"
            " and I can lift it up one meter, which I have done now."
        ),
    ]
    segments, info = pipeline.transcribe(audio, clip_timestamps=clip_timestamps)
    segments = list(segments)

    assert len(segments) == 2
    for segment, clip, transcript in zip(segments, clip_timestamps, transcripts):
        assert clip["start"] == segment.start
        assert clip["end"] == segment.end
        assert segment.text == transcript