Refactor of language detection functions (#1146)

* Added support for new options in batched transcription (see the usage sketch after this list):
  * `language_detection_threshold`
  * `language_detection_segments`
* Updated `WhisperModel.detect_language` to include the improved language detection from #732 and added docstrings; it is now used inside the `transcribe` function.
* Removed the following functions as they are no longer needed:
  * `WhisperModel.detect_language_multi_segment` and its test
  * `BatchedInferencePipeline.get_language_and_tokenizer`
* Added tests for empty audio inputs
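
A minimal usage sketch of the new batched options (the model size, audio file name, and threshold/segment values below are illustrative, not taken from this commit):

```python
from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("tiny")
pipeline = BatchedInferencePipeline(model=model)

# Scan up to four 30-second segments and stop early once a language token's
# probability exceeds 0.8; otherwise the majority vote across segments is used.
segments, info = pipeline.transcribe(
    "audio.mp3",
    language_detection_segments=4,
    language_detection_threshold=0.8,
)
print(info.language, info.language_probability)
```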
Author: Mahmoud Ashraf
Date: 2024-11-16 12:53:07 +02:00
Committed by: GitHub
Parent: 53bbe54016
Commit: a6f8fbae00

3 changed files with 153 additions and 345 deletions

README.md

@@ -164,17 +164,6 @@ segments, _ = model.transcribe("audio.mp3")
 segments = list(segments)  # The transcription will actually run here.
 ```
-### Multi-Segment Language Detection
-To directly use the model for improved language detection, the following code snippet can be used:
-```python
-from faster_whisper import WhisperModel
-model = WhisperModel("turbo", device="cuda", compute_type="float16")
-language_info = model.detect_language_multi_segment("audio.mp3")
-```
-
 ### Batched Transcription
 The following code snippet illustrates how to run batched transcription on an example audio file. `BatchedInferencePipeline.transcribe` is a drop-in replacement for `WhisperModel.transcribe`
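
Since the Multi-Segment Language Detection section is removed above, a rough sketch of the replacement call path is the consolidated `WhisperModel.detect_language` shown later in this commit (not part of this diff; the model size and file name are illustrative):

```python
from faster_whisper import WhisperModel, decode_audio

model = WhisperModel("tiny")
audio = decode_audio("audio.mp3")  # 1D float32 waveform resampled for the model

language, language_probability, all_language_probs = model.detect_language(
    audio,
    vad_filter=True,  # optionally drop non-speech before detection
)
print(language, language_probability)
```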

faster_whisper/transcribe.py

@@ -2,10 +2,8 @@ import itertools
 import json
 import logging
 import os
-import random
 import zlib
-from collections import Counter, defaultdict
 from dataclasses import asdict, dataclass
 from inspect import signature
 from math import ceil
@@ -194,45 +192,11 @@ class BatchedInferencePipeline:

         return segmented_outputs

-    def get_language_and_tokenizer(
-        self, audio, task: Optional[str] = None, language: Optional[str] = None
-    ):
-        all_language_probs = None
-        language_probability = 1.0
-
-        if self.tokenizer is None:
-            if not language:
-                (
-                    language,
-                    language_probability,
-                    all_language_probs,
-                ) = self.model.detect_language(audio)
-            task = task or "transcribe"
-            self.tokenizer = Tokenizer(
-                self.model.hf_tokenizer,
-                self.model.model.is_multilingual,
-                task=task,
-                language=language,
-            )
-        else:
-            if task is not None:
-                self.tokenizer.task = self.tokenizer.tokenizer.token_to_id(
-                    f"<|{task}|>"
-                )
-            if language is not None:
-                self.tokenizer.language = self.tokenizer.tokenizer.token_to_id(
-                    f"<|{language}|>"
-                )
-                self.tokenizer.language_code = language
-
-        return language, language_probability, task, all_language_probs
-
     def transcribe(
         self,
         audio: Union[str, BinaryIO, np.ndarray],
         language: Optional[str] = None,
-        task: str = None,
+        task: str = "transcribe",
         log_progress: bool = False,
         beam_size: int = 5,
         best_of: int = 5,
@@ -267,6 +231,8 @@ class BatchedInferencePipeline:
         clip_timestamps: Optional[List[dict]] = None,
         batch_size: int = 16,
         hotwords: Optional[str] = None,
+        language_detection_threshold: Optional[float] = 0.5,
+        language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """transcribe audio in chunks in batched fashion and return with language info.
@@ -326,6 +292,9 @@ class BatchedInferencePipeline:
             batch_size: the maximum number of parallel requests to model for decoding.
             hotwords:
                 Hotwords/hint phrases to the model. Has no effect if prefix is not None.
+            language_detection_threshold: If the maximum probability of the language tokens is
+                higher than this value, the language is detected.
+            language_detection_segments: Number of segments to consider for the language detection.

         Static params: (Fixed for batched version)
             max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
@@ -390,28 +359,68 @@
                 "No clip timestamps found. "
                 "Set 'vad_filter' to True or provide 'clip_timestamps'."
             )
-        if self.model.model.is_multilingual:
-            language = language or self.preset_language
-        elif language != "en":
-            if language is not None:
-                self.model.logger.warning(
-                    f"English-only model is used, but {language} language is"
-                    " chosen, setting language to 'en'."
-                )
-            language = "en"
-
-        (
-            language,
-            language_probability,
-            task,
-            all_language_probs,
-        ) = self.get_language_and_tokenizer(audio, task, language)

         duration_after_vad = (
             sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
             / sampling_rate
         )

+        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
+        features = (
+            [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
+            if duration_after_vad
+            else []
+        )
+
+        all_language_probs = None
+        # detecting the language if not provided
+        if language is None:
+            if not self.model.model.is_multilingual:
+                language = "en"
+                language_probability = 1
+            else:
+                (
+                    language,
+                    language_probability,
+                    all_language_probs,
+                ) = self.model.detect_language(
+                    features=np.concatenate(
+                        features
+                        + [
+                            np.full((self.model.model.n_mels, 1), -1.5, dtype="float32")
+                        ],
+                        axis=1,
+                    ),  # add a dummy feature to account for empty audio
+                    language_detection_segments=language_detection_segments,
+                    language_detection_threshold=language_detection_threshold,
+                )
+
+                self.model.logger.info(
+                    "Detected language '%s' with probability %.2f",
+                    language,
+                    language_probability,
+                )
+        else:
+            if not self.model.model.is_multilingual and language != "en":
+                self.model.logger.warning(
+                    "The current model is English-only but the language parameter is set to '%s'; "
+                    "using 'en' instead." % language
+                )
+                language = "en"
+
+            language_probability = 1
+
+        self.tokenizer = Tokenizer(
+            self.model.hf_tokenizer,
+            self.model.model.is_multilingual,
+            task=task,
+            language=language,
+        )
+
+        features = (
+            np.stack([pad_or_trim(feature) for feature in features]) if features else []
+        )
+
         # batched options: see the difference with default options in WhisperModel
         batched_options = TranscriptionOptions(
             beam_size=beam_size,
@@ -456,23 +465,6 @@ class BatchedInferencePipeline:
             all_language_probs=all_language_probs,
         )

-        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
-        features = (
-            np.stack(
-                [
-                    pad_or_trim(
-                        self.model.feature_extractor(chunk)[
-                            ...,
-                            : chunk.shape[0] // self.model.feature_extractor.hop_length,
-                        ]
-                    )
-                    for chunk in audio_chunks
-                ]
-            )
-            if duration_after_vad
-            else []
-        )
-
         segments = self._batched_segments_generator(
             features,
             chunks_metadata,
@@ -518,9 +510,6 @@ class BatchedInferencePipeline:
                     pbar.update(1)

        pbar.close()
-        # revert the tokenizer if multilingual inference is enabled
-        if self.preset_language is None:
-            self.tokenizer = None
         self.last_speech_timestamp = 0.0
@@ -835,11 +824,6 @@ class WhisperModel:
                 language = "en"
                 language_probability = 1
             else:
-                if (
-                    language_detection_segments is None
-                    or language_detection_segments < 1
-                ):
-                    language_detection_segments = 1
                 start_timestamp = (
                     float(clip_timestamps.split(",")[0])
                     if isinstance(clip_timestamps, str)
@@ -851,41 +835,15 @@ class WhisperModel:
                     if start_timestamp * self.frames_per_second < content_frames
                     else 0
                 )
-                end_frames = min(
-                    seek
-                    + self.feature_extractor.nb_max_frames
-                    * language_detection_segments,
-                    content_frames,
-                )
-                detected_language_info = {}
-                while seek <= end_frames:
-                    segment = features[
-                        :, seek : seek + self.feature_extractor.nb_max_frames
-                    ]
-                    encoder_output = self.encode(pad_or_trim(segment))
-                    # results is a list of tuple[str, float] with language names and
-                    # probabilities.
-                    results = self.model.detect_language(encoder_output)[0]
-                    # Parse language names to strip out markers
-                    all_language_probs = [
-                        (token[2:-2], prob) for (token, prob) in results
-                    ]
-                    # Get top language token and probability
-                    language, language_probability = all_language_probs[0]
-                    if language_probability > language_detection_threshold:
-                        break
-                    detected_language_info.setdefault(language, []).append(
-                        language_probability
-                    )
-                    seek += segment.shape[-1]
-                else:
-                    # If no language detected for all segments, the majority vote of the highest
-                    # projected languages for all segments is used to determine the language.
-                    language = max(
-                        detected_language_info,
-                        key=lambda lang: len(detected_language_info[lang]),
-                    )
-                    language_probability = max(detected_language_info[language])
+                (
+                    language,
+                    language_probability,
+                    all_language_probs,
+                ) = self.detect_language(
+                    features=features[..., seek:],
+                    language_detection_segments=language_detection_segments,
+                    language_detection_threshold=language_detection_threshold,
+                )

                 self.logger.info(
                     "Detected language '%s' with probability %.2f",
@@ -1782,224 +1740,81 @@ class WhisperModel:

         return encoder_output, output

-    def detect_language(self, audio: np.ndarray):
-        segment = self.feature_extractor(audio)[
-            :, : self.feature_extractor.nb_max_frames
-        ]
-        encoder_output = self.encode(pad_or_trim(segment))
-        results = self.model.detect_language(encoder_output)
-        language_token, language_probability = results[0][0]
-        language = language_token[2:-2]
-        self.logger.info(
-            f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..."
-        )
-        all_language_probs = [(token[2:-2], prob) for (token, prob) in results[0]]
+    def detect_language(
+        self,
+        audio: Optional[np.ndarray] = None,
+        features: Optional[np.ndarray] = None,
+        vad_filter: bool = False,
+        vad_parameters: Union[dict, VadOptions] = None,
+        language_detection_segments: int = 1,
+        language_detection_threshold: float = 0.5,
+    ) -> Tuple[str, float, List[Tuple[str, float]]]:
+        """
+        Use Whisper to detect the language of the input audio or features.
+
+        Arguments:
+            audio: Input audio signal, must be a 1D float array sampled at 16khz.
+            features: Input Mel spectrogram features, must be a float array with
+                shape (n_mels, n_frames), if `audio` is provided, the features will be ignored.
+                Either `audio` or `features` must be provided.
+            vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
+                without speech. This step is using the Silero VAD model.
+            vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
+                parameters and default values in the class `VadOptions`).
+            language_detection_threshold: If the maximum probability of the language tokens is
+                higher than this value, the language is detected.
+            language_detection_segments: Number of segments to consider for the language detection.
+
+        Returns:
+            language: Detected language.
+            language_probability: Probability of the detected language.
+            all_language_probs: List of tuples with all language names and probabilities.
+        """
+        assert (
+            audio is not None or features is not None
+        ), "Either `audio` or `features` must be provided."
+
+        if audio is not None:
+            if vad_filter:
+                speech_chunks = get_speech_timestamps(audio, vad_parameters)
+                audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
+                audio = np.concatenate(audio_chunks, axis=0)
+
+            audio = audio[
+                : language_detection_segments * self.feature_extractor.n_samples
+            ]
+            features = self.feature_extractor(audio)
+
+        features = features[
+            ..., : language_detection_segments * self.feature_extractor.nb_max_frames
+        ]
+
+        detected_language_info = {}
+        for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
+            encoder_output = self.encode(
+                pad_or_trim(features[..., i : i + self.feature_extractor.nb_max_frames])
+            )
+            # results is a list of tuple[str, float] with language names and probabilities.
+            results = self.model.detect_language(encoder_output)[0]
+            # Parse language names to strip out markers
+            all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
+            # Get top language token and probability
+            language, language_probability = all_language_probs[0]
+            if language_probability > language_detection_threshold:
+                break
+            detected_language_info.setdefault(language, []).append(language_probability)
+        else:
+            # If no language detected for all segments, the majority vote of the highest
+            # projected languages for all segments is used to determine the language.
+            language = max(
+                detected_language_info,
+                key=lambda lang: len(detected_language_info[lang]),
+            )
+            language_probability = max(detected_language_info[language])
+
         return language, language_probability, all_language_probs
-
-    def detect_language_multi_segment(
-        self, audio: Union[str, BinaryIO, np.ndarray], params: Optional[dict] = None
-    ):
-        """
-        Detect language based on N highly-confident segments of a language.
-        """
-        # The threshold is used to decide if the audio is silence or not.
-        # The default is 0.02 (2.0%) i.e, if more than 2.0% of the audio is silent,
-        # the audio is considered as silence.
-        if not params:
-            params = {
-                "multilingual": False,
-                "speech_percentage_threshold": 0.02,
-                "language_detection_segments": 4,
-                "vad_filter": True,
-                "vad_min_silence_duration": 2500,
-                "language_threshold": 0.7,
-            }
-
-        if params.get("multilingual", False):
-            logging.warning(
-                "lang_id is not supported for multilingual audios, detecting the major language."
-            )
-
-        speech_percentage_threshold = params.get("speech_percentage_threshold", 0.02)
-        language_threshold = params.get("language_threshold", 0.7)
-        num_detection_segments = params.get("language_detection_segments", 4)
-        vad_filter_enabled = params.get("vad_filter", True)
-        vad_params = dict(
-            min_silence_duration_ms=params.get("vad_min_silence_duration", 2500)
-        )
-
-        if vad_filter_enabled:
-            vad_params = VadOptions(**vad_params)
-
-        # decode audio if it is not decoded already
-        sampling_rate = self.feature_extractor.sampling_rate
-        if not isinstance(audio, np.ndarray):
-            audio: np.ndarray = decode_audio(audio, sampling_rate=sampling_rate)
-
-        # calculate duration of audio as number of seconds
-        # audio.shape[0] is the number of samples in the audio
-        # sampling_rate is the number of samples per second
-        # if we divide the number of samples by the number of samples per second,
-        # we get the duration in seconds
-        duration = audio.shape[0] / sampling_rate
-
-        # Check if vad is enabled, and collect voiced segments
-        if vad_filter_enabled:
-            # get chunks of audio that contain speech
-            speech_chunks = get_speech_timestamps(audio, vad_params)
-            # merge chunks of audio that contain speech into a single array
-            audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
-            audio = np.concatenate(audio_chunks, axis=0)
-
-            # calculate new duration of audio without silence
-            duration_vad = audio.shape[0] / sampling_rate
-
-            logging.debug(
-                f"Lang ID: VAD filter removed {duration - duration_vad} sec of audio"
-            )
-
-            # if the audio after VAD is less than 2% of the original audio, consider it as silence
-            if duration_vad / duration < speech_percentage_threshold:
-                return {"language_code": None, "language_confidence": 1.0}
-
-            # update duration to be the duration after VAD
-            duration = duration_vad
-
-        # if the duration of the audio is less than 1 second, consider it as silence
-        if duration < 1.0:
-            return {"language_code": None, "language_confidence": 1.0}
-
-        # number of feature frames in 30 seconds of audio is 3000
-        nb_max_frames = self.feature_extractor.nb_max_frames
-
-        # extract features from audio with padding (default)
-        features = self.feature_extractor(audio)
-
-        # number of segments in the audio
-        num_segments = features.shape[-1] // nb_max_frames
-        # more number of segments than possible with the duration of file
-        if num_detection_segments > num_segments:
-            logging.warning(
-                f"Lang ID: Can not have more segments, setting {num_segments} segments."
-            )
-            num_detection_segments = num_segments
-
-        # create a list of indices to randomly select segments from
-        indices = list(range(num_detection_segments))
-
-        # fix seed to get deterministic results
-        random.seed(0)
-        random.shuffle(indices)
-
-        detected_languages = []
-        all_language_probabilities = defaultdict(list)
-        confident_language_probabilities = defaultdict(list)
-        num_confident_segments_per_language = defaultdict(int)
-
-        # Iterate over the randomly selected indices of the segments.
-        #
-        # For each segment, extract features and detect language.
-        #
-        # If the language is confident, add it to the list of confident segments for that language.
-        #
-        # If the number of confident segments for a language
-        # is greater than or equal to the number of detection segments,
-        # return the language and the average probability of the language.
-        #
-        # If we are unable to get sufficient number of confident predcitions,
-        # return the most frequently detected language with maximum probability.
-        #
-        # We need to get sufficient number of confident predictions per language, not in total.
-
-        for i in indices:
-            segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
-            try:
-                encoder_output = self.encode(pad_or_trim(segment_features))
-                results = self.model.detect_language(encoder_output)[0]
-            except ValueError as e:  # or RuntimeError
-                logging.error(f"Inference error:{e}")
-
-            # results is the list of classes (languages) and their probabilities (descending),
-            # for eg: [('<|de|>', 0.482177734375),('<|en|>', 0.283447265625),...]
-
-            # take top language token and probability
-            # and parse language token to strip out markers
-            # for eg: '<|de|>' -> 'de'
-            language_token = results[0][0]
-            language = language_token[2:-2]
-            language_probability = results[0][1]
-
-            detected_languages.append(language)
-            all_language_probabilities[language].append(language_probability)
-
-            # only consider if the language prediction is confident
-            if language_probability > language_threshold:
-                num_confident_segments_per_language[language] += 1
-
-                # Add language and probability to the list of languages when it is confident
-                confident_language_probabilities[language].append(language_probability)
-
-                # return the language when sufficient number of confident segments is achieved
-                if (
-                    num_confident_segments_per_language[language]
-                    >= num_detection_segments
-                ):
-                    # Considering the average probability of only confident segments
-                    mean = sum(confident_language_probabilities[language]) / len(
-                        confident_language_probabilities[language]
-                    )
-                    return {
-                        "language_code": language,
-                        "language_confidence": mean,
-                    }
-
-        # if we are unable to get sufficient number of confident predictions,
-        # return the most frequently detected language.
-        # if there is a tie, return the one with maximum average probability.
-        counter = Counter(detected_languages)
-
-        # Define the key function to select frequent language with attached probabilities
-        def key_func(language):
-            # Calculate the frequency of the language
-            frequency = counter[language]
-            # Calculate the average probability of the language
-            prob_avg = sum(all_language_probabilities[language]) / len(
-                all_language_probabilities[language]
-            )
-            return frequency, prob_avg
-
-        if detected_languages:
-            # Use the key function to find the language with maximum frequency and probability
-            max_language = max(detected_languages, key=key_func)
-            max_probability = sum(all_language_probabilities[max_language]) / len(
-                all_language_probabilities[max_language]
-            )
-
-            # Do additional checks for silence for non-confident case
-            # calculate RMS amplitude and DC offset
-            dc_offset = audio.mean()
-            audio_minus_dc_offset = audio - dc_offset
-            is_silent = (
-                all(np.abs(audio) < 0.1)
-                or np.sqrt(np.mean(audio_minus_dc_offset**2)) < 0.01
-            )
-
-            if is_silent:
-                return {"language_code": None, "language_confidence": 1.0}
-
-            return {
-                "language_code": max_language,
-                "language_confidence": max_probability,
-            }
-
-        # Language is not detected for any segment and none of prev conditions met
-        return {"language_code": None, "language_confidence": 1.0}

 def restore_speech_timestamps(
     segments: Iterable[Segment],
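
As an aside (not part of this diff), the consolidated `detect_language` method added above also accepts precomputed Mel features instead of raw audio; a rough usage sketch with illustrative model size, file name, and parameter values:

```python
from faster_whisper import WhisperModel, decode_audio

model = WhisperModel("tiny")
audio = decode_audio("audio.mp3")

# Compute the Mel spectrogram once; when only `features` is passed,
# detection slices it into 30-second windows internally.
features = model.feature_extractor(audio)

language, probability, all_probs = model.detect_language(
    features=features,
    language_detection_segments=2,  # scan at most two windows
    language_detection_threshold=0.7,  # stop early above this probability
)
print(language, probability)
```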

tests/test_transcribe.py

@@ -1,5 +1,7 @@
 import os

+import numpy as np
+
 from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
 from faster_whisper.tokenizer import Tokenizer
 from faster_whisper.transcribe import get_suppressed_tokens
@@ -87,6 +89,15 @@ def test_batched_transcribe(physcisworks_path):
     assert len(segments) > 7


+def test_empty_audio():
+    audio = np.asarray([], dtype="float32")
+    model = WhisperModel("tiny")
+    pipeline = BatchedInferencePipeline(model=model)
+    assert list(model.transcribe(audio)[0]) == []
+    assert list(pipeline.transcribe(audio)[0]) == []
+    model.detect_language(audio)
+
+
 def test_prefix_with_timestamps(jfk_path):
     model = WhisperModel("tiny")
     segments, _ = model.transcribe(jfk_path, prefix="And so my fellow Americans")
@@ -147,13 +158,6 @@ def test_stereo_diarization(data_dir):
     assert transcription == "The horizon seems extremely distant."


-def test_multisegment_lang_id(physcisworks_path):
-    model = WhisperModel("tiny")
-    language_info = model.detect_language_multi_segment(physcisworks_path)
-    assert language_info["language_code"] == "en"
-    assert language_info["language_confidence"] > 0.8
-
-
 def test_suppressed_tokens_minus_1():
     model = WhisperModel("tiny.en")