Mirror of https://github.com/SYSTRAN/faster-whisper.git (synced 2026-01-09 13:38:01 -05:00)
refactor multilingual option (#1148)
* Added test for `multilingual` option with English-German audio
* Removed the `output_language` argument, as it is redundant: the same functionality is available with `task="translate"`
* Use the correct `encoder_output` for language detection in sequential transcription
* Enabled `multilingual` functionality for batched inference
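At a glance, the user-facing change looks like this (a minimal sketch; the model name "tiny" and the file audio.wav are illustrative placeholders, not part of the commit):

    from faster_whisper import BatchedInferencePipeline, WhisperModel

    model = WhisperModel("tiny")

    # Per-segment language detection is now a plain flag:
    segments, info = model.transcribe("audio.wav", multilingual=True)

    # The removed output_language="en" behavior is already covered by the
    # standard task argument:
    segments, info = model.transcribe("audio.wav", task="translate")

    # multilingual now also works with batched inference:
    pipeline = BatchedInferencePipeline(model)
    segments, info = pipeline.transcribe("audio.wav", multilingual=True)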
@@ -93,7 +93,6 @@ class TranscriptionOptions:
     prepend_punctuations: str
     append_punctuations: str
     multilingual: bool
-    output_language: Optional[str]
     max_new_tokens: Optional[int]
     clip_timestamps: Union[str, List[float]]
     hallucination_silence_threshold: Optional[float]
@@ -210,10 +209,21 @@ class BatchedInferencePipeline:
         )

         encoder_output = self.model.encode(features)
+        prompts = [prompt.copy() for _ in range(batch_size)]
+
+        if options.multilingual:
+            language_tokens = [
+                tokenizer.tokenizer.token_to_id(segment_langs[0][0])
+                for segment_langs in self.model.model.detect_language(encoder_output)
+            ]
+            language_token_index = prompt.index(tokenizer.language)
+
+            for i, language_token in enumerate(language_tokens):
+                prompts[i][language_token_index] = language_token

         results = self.model.model.generate(
             encoder_output,
-            [prompt] * batch_size,
+            prompts,
             beam_size=options.beam_size,
             patience=options.patience,
             length_penalty=options.length_penalty,
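The hunk above is the core of the batched change: instead of replicating one prompt across the batch, the pipeline copies the prompt per item and swaps in each item's detected language token. Below is a standalone sketch of that patching step, with hypothetical token ids and a hypothetical token_to_id mapping standing in for the tokenizer (the diff suggests detect_language returns, per batch item, (language_token, probability) pairs sorted by probability):

    # Hypothetical stand-ins, not the faster-whisper tokenizer itself:
    prompt = [50258, 50259, 50359]          # e.g. <|startoftranscript|><|en|><|transcribe|>
    tokenizer_language = 50259              # current <|en|> token id
    detected = [[("<|de|>", 0.9)], [("<|en|>", 0.8)]]   # one list per batch item
    token_to_id = {"<|de|>": 50261, "<|en|>": 50259}    # hypothetical mapping

    prompts = [prompt.copy() for _ in detected]
    language_tokens = [token_to_id[langs[0][0]] for langs in detected]
    index = prompt.index(tokenizer_language)
    for i, tok in enumerate(language_tokens):
        prompts[i][index] = tok

    # Each batch item now carries its own detected language token:
    assert prompts == [[50258, 50261, 50359], [50258, 50259, 50359]]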
@@ -279,7 +289,6 @@ class BatchedInferencePipeline:
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
         multilingual: bool = False,
-        output_language: Optional[str] = None,
         vad_filter: bool = True,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
@@ -322,6 +331,7 @@ class BatchedInferencePipeline:
            with the next word
          append_punctuations: If word_timestamps is True, merge these punctuation symbols
            with the previous word
+         multilingual: Perform language detection on every segment.
          vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
            without speech. This step is using the Silero VAD model
            https://github.com/snakers4/silero-vad.
@@ -360,10 +370,6 @@ class BatchedInferencePipeline:
            Arg has effect only if condition_on_previous_text is True. Set at 0.5
          prefix: Optional text to provide as a prefix at the beginning of each window.
          max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
-         multilingual: If True, perform transcription on multilingual videos. Set as False.
-         output_language: Valid only if multilingual is set to True.
-           Specifies the string representing the output language. One of
-           'en' (English) or 'hybrid' (code-switched transcription). set as None.
          hallucination_silence_threshold: Optional[float]
            When word_timestamps is True, skip silent periods longer than this threshold
            (in seconds) when a possible hallucination is detected. set as None.
@@ -376,6 +382,13 @@ class BatchedInferencePipeline:

         sampling_rate = self.model.feature_extractor.sampling_rate

+        if multilingual and not self.model.model.is_multilingual:
+            self.model.logger.warning(
+                "The current model is English-only but the multilingual parameter is set to"
+                " True; setting to False instead."
+            )
+            multilingual = False
+
         if not isinstance(audio, np.ndarray):
             audio = decode_audio(audio, sampling_rate=sampling_rate)
         duration = audio.shape[0] / sampling_rate
@@ -498,8 +511,7 @@ class BatchedInferencePipeline:
             condition_on_previous_text=False,
             clip_timestamps=clip_timestamps,
             prompt_reset_on_temperature=0.5,
-            multilingual=False,
-            output_language=None,
+            multilingual=multilingual,
             without_timestamps=without_timestamps,
             max_initial_timestamp=0.0,
         )
@@ -721,7 +733,6 @@ class WhisperModel:
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
         multilingual: bool = False,
-        output_language: Optional[str] = None,
         vad_filter: bool = False,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
@@ -781,12 +792,7 @@ class WhisperModel:
            with the next word
          append_punctuations: If word_timestamps is True, merge these punctuation symbols
            with the previous word
-         multilingual: If True, perform transcription on multilingual videos
-           and return the transcript based
-           on the 'output_language' flag.
-         output_language: Valid only if multilingual is set to True.
-           Specifies the string representing the output language. One of
-           'en' (English) or 'hybrid' (code-switched transcription).
+         multilingual: Perform language detection on every segment.
          vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
            without speech. This step is using the Silero VAD model
            https://github.com/snakers4/silero-vad.
@@ -817,6 +823,13 @@ class WhisperModel:

         sampling_rate = self.feature_extractor.sampling_rate

+        if multilingual and not self.model.is_multilingual:
+            self.logger.warning(
+                "The current model is English-only but the multilingual parameter is set to"
+                " True; setting to False instead."
+            )
+            multilingual = False
+
         if not isinstance(audio, np.ndarray):
             audio = decode_audio(audio, sampling_rate=sampling_rate)

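The guard above (mirrored in BatchedInferencePipeline earlier in the diff) means an English-only model degrades gracefully instead of failing. A hedged usage sketch ("tiny.en" and audio.wav are placeholders):

    from faster_whisper import WhisperModel

    model = WhisperModel("tiny.en")
    # Logs "The current model is English-only but the multilingual parameter
    # is set to True; setting to False instead." and then proceeds with
    # multilingual=False:
    segments, info = model.transcribe("audio.wav", multilingual=True)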
@@ -863,13 +876,6 @@ class WhisperModel:
         encoder_output = None
         all_language_probs = None

-        # setting output_language for multilingual videos
-        if multilingual:
-            if output_language is None:
-                output_language = "en"
-            elif output_language not in ["en", "hybrid"]:
-                raise ValueError("Output language needs to be one of 'en'/'hybrid'.")
-
         # detecting the language if not provided
         if language is None:
             if not self.model.is_multilingual:
@@ -949,7 +955,6 @@ class WhisperModel:
             prepend_punctuations=prepend_punctuations,
             append_punctuations=append_punctuations,
             multilingual=multilingual,
-            output_language=output_language,
             max_new_tokens=max_new_tokens,
             clip_timestamps=clip_timestamps,
             hallucination_silence_threshold=hallucination_silence_threshold,
@@ -1139,27 +1144,17 @@ class WhisperModel:

             previous_tokens = all_tokens[prompt_reset_since:]

-            if encoder_output is None:
+            if seek > 0 or encoder_output is None:
                 encoder_output = self.encode(segment)

-            # Perform language detection at every segment to update task based on output language,
-            # if the language is english, task is transcribe,
-            # else the task is translate to english (default)
-            # or transcribe if 'output_language' is 'hybrid'.
             if options.multilingual:
                 results = self.model.detect_language(encoder_output)
                 language_token, language_probability = results[0][0]
                 language = language_token[2:-2]
-                if options.output_language == "en" and language != "en":
-                    task = "translate"
-                else:
-                    task = "transcribe"
-
-                # Update tokenizer based on task and language
-                tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>")
+
                 tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
                 tokenizer.language_code = language
-            # Update prompt based on task and language
+
             prompt = self.get_prompt(
                 tokenizer,
                 previous_tokens,
@@ -1168,9 +1163,6 @@ class WhisperModel:
                 hotwords=options.hotwords,
             )

-            if seek > 0 or encoder_output is None:
-                encoder_output = self.encode(segment)
-
             (
                 result,
                 avg_logprob,
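These last two hunks implement the "correct encoder_output" bullet from the commit message: re-encoding used to happen after language detection, so for every window past the first, detection ran on the previous window's encoder output. A toy illustration of the corrected ordering (encode and detect_language here are stand-ins, not the faster-whisper API):

    def encode(segment):
        return f"encoded({segment})"

    def detect_language(encoder_output):
        return f"lang_of({encoder_output})"

    windows = ["window0", "window1"]
    encoder_output, seek = None, 0
    for segment in windows:
        # Refresh the encoder output *before* detection, as in the fixed code;
        # the old order only encoded here on the first window.
        if seek > 0 or encoder_output is None:
            encoder_output = encode(segment)
        print(detect_language(encoder_output))  # always the current window
        seek += 1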
BIN  tests/data/multilingual.mp3  Normal file
Binary file not shown.
@@ -158,6 +158,63 @@ def test_stereo_diarization(data_dir):
     assert transcription == "The horizon seems extremely distant."


+def test_multilingual_transcription(data_dir):
+    model = WhisperModel("tiny")
+    pipeline = BatchedInferencePipeline(model)
+
+    audio_path = os.path.join(data_dir, "multilingual.mp3")
+    audio = decode_audio(audio_path)
+
+    segments, info = model.transcribe(
+        audio,
+        multilingual=True,
+        without_timestamps=True,
+        condition_on_previous_text=False,
+    )
+    segments = list(segments)
+
+    assert (
+        segments[0].text
+        == " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
+        " software and associated documentation files to deal in the software without restriction,"
+        " including without limitation the rights to use, copy, modify, merge, publish, distribute"
+        ", sublicence, and or cell copies of the software, and to permit persons to whom the "
+        "software is furnished to do so, subject to the following conditions. The above copyright"
+        " notice and this permission notice, shall be included in all copies or substantial "
+        "portions of the software."
+    )
+
+    assert (
+        segments[1].text
+        == " Jedem, der dieses Software und die dazu gehöregen Dokumentationsdatein erhält, wird "
+        "hiermit unengeltlich die Genehmigung erteilt, wird der Software und eingeschränkt zu "
+        "verfahren. Dies umfasst insbesondere das Recht, die Software zu verwenden, zu "
+        "vervielfältigen, zu modifizieren, zu Samenzofügen, zu veröffentlichen, zu verteilen, "
+        "unterzulizenzieren und oder kopieren der Software zu verkaufen und diese Rechte "
+        "unterfolgen den Bedingungen anderen zu übertragen."
+    )
+
+    segments, info = pipeline.transcribe(audio, multilingual=True)
+    segments = list(segments)
+
+    assert (
+        segments[0].text
+        == " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
+        " software and associated documentation files to deal in the software without restriction,"
+        " including without limitation the rights to use, copy, modify, merge, publish, distribute"
+        ", sublicence, and or cell copies of the software, and to permit persons to whom the "
+        "software is furnished to do so, subject to the following conditions. The above copyright"
+        " notice and this permission notice, shall be included in all copies or substantial "
+        "portions of the software."
+    )
+    assert (
+        "Dokumentationsdatein erhält, wird hiermit unengeltlich die Genehmigung erteilt,"
+        " wird der Software und eingeschränkt zu verfahren. Dies umfasst insbesondere das Recht,"
+        " die Software zu verwenden, zu vervielfältigen, zu modifizieren"
+        in segments[1].text
+    )
+
+
 def test_suppressed_tokens_minus_1():
     model = WhisperModel("tiny.en")
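Note that the expected strings deliberately preserve the tiny model's transcription errors ("sublicence, and or cell copies", the garbled German), and that the batched assertion on the German segment uses only a substring check (in segments[1].text) rather than strict equality, presumably because batched segmentation can split the windows slightly differently than the sequential path.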