Refactor of language detection functions (#1146)

* Added support for new options for batched transcription (see the sketch after this list):
  * `language_detection_threshold`
  * `language_detection_segments`
* Updated the `WhisperModel.detect_language` function to include the improved language detection from #732 and added docstrings; it is now used inside the `transcribe` function.
* Removed the following functions as they are no longer needed:
  * `WhisperModel.detect_language_multi_segment` and its test
  * `BatchedInferencePipeline.get_language_and_tokenizer`
* Added tests for empty audio inputs
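
A minimal sketch of the new options (model size, file name, and the chosen values here are illustrative, not defaults):

```python
from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("tiny")
pipeline = BatchedInferencePipeline(model=model)

# Look at up to 4 segments of 30 s each and accept a language as soon as its
# probability exceeds 0.8; otherwise a majority vote across segments is used.
segments, info = pipeline.transcribe(
    "audio.mp3",
    language_detection_threshold=0.8,
    language_detection_segments=4,
)
print(info.language, info.language_probability)
```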
Mahmoud Ashraf authored on 2024-11-16 12:53:07 +02:00 (committed by GitHub)
parent 53bbe54016, commit a6f8fbae00
3 changed files with 153 additions and 345 deletions

README.md

@@ -164,17 +164,6 @@ segments, _ = model.transcribe("audio.mp3")
segments = list(segments) # The transcription will actually run here.
```
### Multi-Segment Language Detection
To directly use the model for improved language detection, the following code snippet can be used:
```python
from faster_whisper import WhisperModel
model = WhisperModel("turbo", device="cuda", compute_type="float16")
language_info = model.detect_language_multi_segment("audio.mp3")
```
### Batched Transcription
The following code snippet illustrates how to run batched transcription on an example audio file. `BatchedInferencePipeline.transcribe` is a drop-in replacement for `WhisperModel.transcribe`.
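A sketch of such a call, reusing the `turbo` model setup from the snippet above (the audio file name is a placeholder):

```python
from faster_whisper import WhisperModel, BatchedInferencePipeline

model = WhisperModel("turbo", device="cuda", compute_type="float16")
batched_model = BatchedInferencePipeline(model=model)

# batch_size controls how many chunks are decoded in parallel.
segments, info = batched_model.transcribe("audio.mp3", batch_size=16)

for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```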

faster_whisper/transcribe.py

@@ -2,10 +2,8 @@ import itertools
import json
import logging
import os
import random
import zlib
from collections import Counter, defaultdict
from dataclasses import asdict, dataclass
from inspect import signature
from math import ceil
@@ -194,45 +192,11 @@ class BatchedInferencePipeline:
return segmented_outputs
def get_language_and_tokenizer(
self, audio, task: Optional[str] = None, language: Optional[str] = None
):
all_language_probs = None
language_probability = 1.0
if self.tokenizer is None:
if not language:
(
language,
language_probability,
all_language_probs,
) = self.model.detect_language(audio)
task = task or "transcribe"
self.tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
language=language,
)
else:
if task is not None:
self.tokenizer.task = self.tokenizer.tokenizer.token_to_id(
f"<|{task}|>"
)
if language is not None:
self.tokenizer.language = self.tokenizer.tokenizer.token_to_id(
f"<|{language}|>"
)
self.tokenizer.language_code = language
return language, language_probability, task, all_language_probs
def transcribe(
self,
audio: Union[str, BinaryIO, np.ndarray],
language: Optional[str] = None,
task: str = None,
task: str = "transcribe",
log_progress: bool = False,
beam_size: int = 5,
best_of: int = 5,
@@ -267,6 +231,8 @@ class BatchedInferencePipeline:
clip_timestamps: Optional[List[dict]] = None,
batch_size: int = 16,
hotwords: Optional[str] = None,
language_detection_threshold: Optional[float] = 0.5,
language_detection_segments: int = 1,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""transcribe audio in chunks in batched fashion and return with language info.
@@ -326,6 +292,9 @@ class BatchedInferencePipeline:
batch_size: the maximum number of parallel requests to model for decoding.
hotwords:
Hotwords/hint phrases to the model. Has no effect if prefix is not None.
language_detection_threshold: If the maximum probability of the language tokens is
higher than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
Static params: (Fixed for batched version)
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
@@ -390,28 +359,68 @@ class BatchedInferencePipeline:
"No clip timestamps found. "
"Set 'vad_filter' to True or provide 'clip_timestamps'."
)
if self.model.model.is_multilingual:
language = language or self.preset_language
elif language != "en":
if language is not None:
self.model.logger.warning(
f"English-only model is used, but {language} language is"
" chosen, setting language to 'en'."
)
language = "en"
(
language,
language_probability,
task,
all_language_probs,
) = self.get_language_and_tokenizer(audio, task, language)
duration_after_vad = (
sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
/ sampling_rate
)
audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
features = (
[self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
if duration_after_vad
else []
)
all_language_probs = None
# detecting the language if not provided
if language is None:
if not self.model.model.is_multilingual:
language = "en"
language_probability = 1
else:
(
language,
language_probability,
all_language_probs,
) = self.model.detect_language(
features=np.concatenate(
features
+ [
np.full((self.model.model.n_mels, 1), -1.5, dtype="float32")
],
axis=1,
), # add a dummy feature to account for empty audio
language_detection_segments=language_detection_segments,
language_detection_threshold=language_detection_threshold,
)
self.model.logger.info(
"Detected language '%s' with probability %.2f",
language,
language_probability,
)
else:
if not self.model.model.is_multilingual and language != "en":
self.model.logger.warning(
"The current model is English-only but the language parameter is set to '%s'; "
"using 'en' instead." % language
)
language = "en"
language_probability = 1
self.tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
language=language,
)
features = (
np.stack([pad_or_trim(feature) for feature in features]) if features else []
)
# batched options: see the difference with default options in WhisperModel
batched_options = TranscriptionOptions(
beam_size=beam_size,
@@ -456,23 +465,6 @@ class BatchedInferencePipeline:
all_language_probs=all_language_probs,
)
audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
features = (
np.stack(
[
pad_or_trim(
self.model.feature_extractor(chunk)[
...,
: chunk.shape[0] // self.model.feature_extractor.hop_length,
]
)
for chunk in audio_chunks
]
)
if duration_after_vad
else []
)
segments = self._batched_segments_generator(
features,
chunks_metadata,
@@ -518,9 +510,6 @@ class BatchedInferencePipeline:
pbar.update(1)
pbar.close()
# revert the tokenizer if multilingual inference is enabled
if self.preset_language is None:
self.tokenizer = None
self.last_speech_timestamp = 0.0
@@ -835,11 +824,6 @@ class WhisperModel:
language = "en"
language_probability = 1
else:
if (
language_detection_segments is None
or language_detection_segments < 1
):
language_detection_segments = 1
start_timestamp = (
float(clip_timestamps.split(",")[0])
if isinstance(clip_timestamps, str)
@@ -851,41 +835,15 @@ class WhisperModel:
if start_timestamp * self.frames_per_second < content_frames
else 0
)
end_frames = min(
seek
+ self.feature_extractor.nb_max_frames
* language_detection_segments,
content_frames,
)
(
language,
language_probability,
all_language_probs,
) = self.detect_language(
features=features[..., seek:],
language_detection_segments=language_detection_segments,
language_detection_threshold=language_detection_threshold,
)
detected_language_info = {}
while seek <= end_frames:
segment = features[
:, seek : seek + self.feature_extractor.nb_max_frames
]
encoder_output = self.encode(pad_or_trim(segment))
# results is a list of tuple[str, float] with language names and
# probabilities.
results = self.model.detect_language(encoder_output)[0]
# Parse language names to strip out markers
all_language_probs = [
(token[2:-2], prob) for (token, prob) in results
]
# Get top language token and probability
language, language_probability = all_language_probs[0]
if language_probability > language_detection_threshold:
break
detected_language_info.setdefault(language, []).append(
language_probability
)
seek += segment.shape[-1]
else:
# If no language was detected above the threshold in any segment, the majority vote of the
# top predicted languages across all segments is used to determine the language.
language = max(
detected_language_info,
key=lambda lang: len(detected_language_info[lang]),
)
language_probability = max(detected_language_info[language])
self.logger.info(
"Detected language '%s' with probability %.2f",
@@ -1782,224 +1740,81 @@ class WhisperModel:
return encoder_output, output
def detect_language(self, audio: np.ndarray):
segment = self.feature_extractor(audio)[
:, : self.feature_extractor.nb_max_frames
def detect_language(
self,
audio: Optional[np.ndarray] = None,
features: Optional[np.ndarray] = None,
vad_filter: bool = False,
vad_parameters: Union[dict, VadOptions] = None,
language_detection_segments: int = 1,
language_detection_threshold: float = 0.5,
) -> Tuple[str, float, List[Tuple[str, float]]]:
"""
Use Whisper to detect the language of the input audio or features.
Arguments:
audio: Input audio signal, must be a 1D float array sampled at 16 kHz.
features: Input Mel spectrogram features, must be a float array with
shape (n_mels, n_frames), if `audio` is provided, the features will be ignored.
Either `audio` or `features` must be provided.
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step uses the Silero VAD model.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
language_detection_threshold: If the maximum probability of the language tokens is
higher than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
Returns:
language: Detected language.
language_probability: Probability of the detected language.
all_language_probs: List of tuples with all language names and probabilities.
"""
assert (
audio is not None or features is not None
), "Either `audio` or `features` must be provided."
if audio is not None:
if vad_filter:
speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
audio = np.concatenate(audio_chunks, axis=0)
audio = audio[
: language_detection_segments * self.feature_extractor.n_samples
]
features = self.feature_extractor(audio)
features = features[
..., : language_detection_segments * self.feature_extractor.nb_max_frames
]
encoder_output = self.encode(pad_or_trim(segment))
results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
self.logger.info(
f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..."
)
all_language_probs = [(token[2:-2], prob) for (token, prob) in results[0]]
detected_language_info = {}
for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
encoder_output = self.encode(
pad_or_trim(features[..., i : i + self.feature_extractor.nb_max_frames])
)
# results is a list of tuple[str, float] with language names and probabilities.
results = self.model.detect_language(encoder_output)[0]
# Parse language names to strip out markers
all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
# Get top language token and probability
language, language_probability = all_language_probs[0]
if language_probability > language_detection_threshold:
break
detected_language_info.setdefault(language, []).append(language_probability)
else:
# If no language was detected above the threshold in any segment, the majority vote of the
# top predicted languages across all segments is used to determine the language.
language = max(
detected_language_info,
key=lambda lang: len(detected_language_info[lang]),
)
language_probability = max(detected_language_info[language])
return language, language_probability, all_language_probs
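For illustration, a short sketch of calling the refactored `detect_language` directly; the file name is a placeholder and the option values are only examples:

```python
from faster_whisper import WhisperModel, decode_audio

model = WhisperModel("tiny")
audio = decode_audio("audio.mp3")  # 1D float32 array sampled at 16 kHz

# Trim silence with VAD first, then scan up to 3 segments of 30 s each.
language, probability, all_language_probs = model.detect_language(
    audio=audio,
    vad_filter=True,
    language_detection_segments=3,
    language_detection_threshold=0.5,
)
print(f"Detected '{language}' with probability {probability:.2f}")
```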
def detect_language_multi_segment(
self, audio: Union[str, BinaryIO, np.ndarray], params: Optional[dict] = None
):
"""
Detect language based on N highly-confident segments of a language.
"""
# The threshold is used to decide if the audio is silence or not.
# The default is 0.02 (2.0%), i.e., if the speech remaining after VAD is less
# than 2.0% of the original audio, the audio is considered silence.
if not params:
params = {
"multilingual": False,
"speech_percentage_threshold": 0.02,
"language_detection_segments": 4,
"vad_filter": True,
"vad_min_silence_duration": 2500,
"language_threshold": 0.7,
}
if params.get("multilingual", False):
logging.warning(
"lang_id is not supported for multilingual audios, detecting the major language."
)
speech_percentage_threshold = params.get("speech_percentage_threshold", 0.02)
language_threshold = params.get("language_threshold", 0.7)
num_detection_segments = params.get("language_detection_segments", 4)
vad_filter_enabled = params.get("vad_filter", True)
vad_params = dict(
min_silence_duration_ms=params.get("vad_min_silence_duration", 2500)
)
if vad_filter_enabled:
vad_params = VadOptions(**vad_params)
# decode audio if it is not decoded already
sampling_rate = self.feature_extractor.sampling_rate
if not isinstance(audio, np.ndarray):
audio: np.ndarray = decode_audio(audio, sampling_rate=sampling_rate)
# calculate duration of audio as number of seconds
# audio.shape[0] is the number of samples in the audio
# sampling_rate is the number of samples per second
# if we divide the number of samples by the number of samples per second,
# we get the duration in seconds
duration = audio.shape[0] / sampling_rate
# Check if vad is enabled, and collect voiced segments
if vad_filter_enabled:
# get chunks of audio that contain speech
speech_chunks = get_speech_timestamps(audio, vad_params)
# merge chunks of audio that contain speech into a single array
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
audio = np.concatenate(audio_chunks, axis=0)
# calculate new duration of audio without silence
duration_vad = audio.shape[0] / sampling_rate
logging.debug(
f"Lang ID: VAD filter removed {duration - duration_vad} sec of audio"
)
# if the audio after VAD is less than 2% of the original audio, consider it as silence
if duration_vad / duration < speech_percentage_threshold:
return {"language_code": None, "language_confidence": 1.0}
# update duration to be the duration after VAD
duration = duration_vad
# if the duration of the audio is less than 1 second, consider it as silence
if duration < 1.0:
return {"language_code": None, "language_confidence": 1.0}
# number of feature frames in 30 seconds of audio is 3000
nb_max_frames = self.feature_extractor.nb_max_frames
# extract features from audio with padding (default)
features = self.feature_extractor(audio)
# number of segments in the audio
num_segments = features.shape[-1] // nb_max_frames
# more detection segments requested than the duration of the file allows
if num_detection_segments > num_segments:
logging.warning(
f"Lang ID: Cannot have more segments, setting {num_segments} segments."
)
num_detection_segments = num_segments
# create a list of indices to randomly select segments from
indices = list(range(num_detection_segments))
# fix seed to get deterministic results
random.seed(0)
random.shuffle(indices)
detected_languages = []
all_language_probabilities = defaultdict(list)
confident_language_probabilities = defaultdict(list)
num_confident_segments_per_language = defaultdict(int)
# Iterate over the randomly selected indices of the segments.
#
# For each segment, extract features and detect language.
#
# If the language is confident, add it to the list of confident segments for that language.
#
# If the number of confident segments for a language
# is greater than or equal to the number of detection segments,
# return the language and the average probability of the language.
#
# If we are unable to get a sufficient number of confident predictions,
# return the most frequently detected language with maximum probability.
#
# We need a sufficient number of confident predictions per language, not in total.
for i in indices:
segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
try:
encoder_output = self.encode(pad_or_trim(segment_features))
results = self.model.detect_language(encoder_output)[0]
except ValueError as e: # or RuntimeError
logging.error(f"Inference error: {e}")
# results is the list of classes (languages) and their probabilities (descending),
# for eg: [('<|de|>', 0.482177734375),('<|en|>', 0.283447265625),...]
# take top language token and probability
# and parse language token to strip out markers
# for eg: '<|de|>' -> 'de'
language_token = results[0][0]
language = language_token[2:-2]
language_probability = results[0][1]
detected_languages.append(language)
all_language_probabilities[language].append(language_probability)
# only consider if the language prediction is confident
if language_probability > language_threshold:
num_confident_segments_per_language[language] += 1
# Add language and probability to the list of languages when it is confident
confident_language_probabilities[language].append(language_probability)
# return the language when sufficient number of confident segments is achieved
if (
num_confident_segments_per_language[language]
>= num_detection_segments
):
# Considering the average probability of only confident segments
mean = sum(confident_language_probabilities[language]) / len(
confident_language_probabilities[language]
)
return {
"language_code": language,
"language_confidence": mean,
}
# if we are unable to get a sufficient number of confident predictions,
# return the most frequently detected language.
# if there is a tie, return the one with maximum average probability.
counter = Counter(detected_languages)
# Define the key function to select frequent language with attached probabilities
def key_func(language):
# Calculate the frequency of the language
frequency = counter[language]
# Calculate the average probability of the language
prob_avg = sum(all_language_probabilities[language]) / len(
all_language_probabilities[language]
)
return frequency, prob_avg
if detected_languages:
# Use the key function to find the language with maximum frequency and probability
max_language = max(detected_languages, key=key_func)
max_probability = sum(all_language_probabilities[max_language]) / len(
all_language_probabilities[max_language]
)
# Do additional checks for silence for non-confident case
# calculate RMS amplitude and DC offset
dc_offset = audio.mean()
audio_minus_dc_offset = audio - dc_offset
is_silent = (
all(np.abs(audio) < 0.1)
or np.sqrt(np.mean(audio_minus_dc_offset**2)) < 0.01
)
if is_silent:
return {"language_code": None, "language_confidence": 1.0}
return {
"language_code": max_language,
"language_confidence": max_probability,
}
# Language was not detected for any segment and none of the previous conditions were met
return {"language_code": None, "language_confidence": 1.0}
def restore_speech_timestamps(
segments: Iterable[Segment],

tests/test_transcribe.py

@@ -1,5 +1,7 @@
import os
import numpy as np
from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import get_suppressed_tokens
@@ -87,6 +89,15 @@ def test_batched_transcribe(physcisworks_path):
assert len(segments) > 7
def test_empty_audio():
audio = np.asarray([], dtype="float32")
model = WhisperModel("tiny")
pipeline = BatchedInferencePipeline(model=model)
assert list(model.transcribe(audio)[0]) == []
assert list(pipeline.transcribe(audio)[0]) == []
model.detect_language(audio)
def test_prefix_with_timestamps(jfk_path):
model = WhisperModel("tiny")
segments, _ = model.transcribe(jfk_path, prefix="And so my fellow Americans")
@@ -147,13 +158,6 @@ def test_stereo_diarization(data_dir):
assert transcription == "The horizon seems extremely distant."
def test_multisegment_lang_id(physcisworks_path):
model = WhisperModel("tiny")
language_info = model.detect_language_multi_segment(physcisworks_path)
assert language_info["language_code"] == "en"
assert language_info["language_confidence"] > 0.8
def test_suppressed_tokens_minus_1():
model = WhisperModel("tiny.en")