Mirror of https://github.com/SYSTRAN/faster-whisper.git (synced 2026-01-09 13:38:01 -05:00)
refactor multilingual option (#1148)

* Added a test for the `multilingual` option with English-German audio.
* Removed the `output_language` argument as redundant; the same functionality is available with `task="translate"`.
* Use the correct `encoder_output` for language detection in sequential transcription.
* Enabled `multilingual` functionality for batched inference.
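For callers migrating, a minimal sketch of the equivalent calls under the new API (illustrative only; "audio.wav" stands in for your input file):

    from faster_whisper import WhisperModel

    model = WhisperModel("tiny")

    # Before this commit: output_language="en" forced English output on
    # multilingual audio. The standard translate task now covers that case:
    segments, info = model.transcribe("audio.wav", task="translate")

    # Before: output_language="hybrid" requested code-switched transcription.
    # Now multilingual=True does this directly, re-detecting the language per segment:
    segments, info = model.transcribe("audio.wav", multilingual=True)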
faster_whisper/transcribe.py

@@ -93,7 +93,6 @@ class TranscriptionOptions:
     prepend_punctuations: str
     append_punctuations: str
     multilingual: bool
-    output_language: Optional[str]
     max_new_tokens: Optional[int]
     clip_timestamps: Union[str, List[float]]
     hallucination_silence_threshold: Optional[float]
@@ -210,10 +209,21 @@ class BatchedInferencePipeline:
         )

         encoder_output = self.model.encode(features)
+        prompts = [prompt.copy() for _ in range(batch_size)]
+
+        if options.multilingual:
+            language_tokens = [
+                tokenizer.tokenizer.token_to_id(segment_langs[0][0])
+                for segment_langs in self.model.model.detect_language(encoder_output)
+            ]
+            language_token_index = prompt.index(tokenizer.language)
+
+            for i, language_token in enumerate(language_tokens):
+                prompts[i][language_token_index] = language_token

         results = self.model.model.generate(
             encoder_output,
-            [prompt] * batch_size,
+            prompts,
             beam_size=options.beam_size,
             patience=options.patience,
             length_penalty=options.length_penalty,
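The batched path now builds one prompt per batch item and, when `multilingual` is set, overwrites the language-token slot of each prompt with the token detected for that segment. A self-contained sketch of just the substitution step (the token IDs below are the standard multilingual Whisper vocabulary, hard-coded for illustration instead of coming from a live tokenizer and `detect_language` call):

    # <|startoftranscript|>, <|en|>, <|transcribe|>, <|notimestamps|>
    prompt = [50258, 50259, 50359, 50363]
    batch_size = 3

    # stand-in for the per-segment detect_language results: en, de, en
    language_tokens = [50259, 50261, 50259]

    prompts = [prompt.copy() for _ in range(batch_size)]
    language_token_index = prompt.index(50259)  # slot holding the language token

    for i, language_token in enumerate(language_tokens):
        prompts[i][language_token_index] = language_token

    assert prompts[1] == [50258, 50261, 50359, 50363]  # second item now prompted as German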
@@ -279,7 +289,6 @@ class BatchedInferencePipeline:
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
         multilingual: bool = False,
-        output_language: Optional[str] = None,
         vad_filter: bool = True,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
@@ -322,6 +331,7 @@ class BatchedInferencePipeline:
                 with the next word
             append_punctuations: If word_timestamps is True, merge these punctuation symbols
                 with the previous word
+            multilingual: Perform language detection on every segment.
             vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
                 without speech. This step is using the Silero VAD model
                 https://github.com/snakers4/silero-vad.
@@ -360,10 +370,6 @@ class BatchedInferencePipeline:
                 Arg has effect only if condition_on_previous_text is True. Set at 0.5
             prefix: Optional text to provide as a prefix at the beginning of each window.
             max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
-            multilingual: If True, perform transcription on multilingual videos. Set as False.
-            output_language: Valid only if multilingual is set to True.
-                Specifies the string representing the output language. One of
-                'en' (English) or 'hybrid' (code-switched transcription). set as None.
             hallucination_silence_threshold: Optional[float]
                 When word_timestamps is True, skip silent periods longer than this threshold
                 (in seconds) when a possible hallucination is detected. set as None.
@@ -376,6 +382,13 @@ class BatchedInferencePipeline:

         sampling_rate = self.model.feature_extractor.sampling_rate

+        if multilingual and not self.model.model.is_multilingual:
+            self.model.logger.warning(
+                "The current model is English-only but the multilingual parameter is set to "
+                "True; setting to False instead."
+            )
+            multilingual = False
+
         if not isinstance(audio, np.ndarray):
             audio = decode_audio(audio, sampling_rate=sampling_rate)
         duration = audio.shape[0] / sampling_rate
@@ -498,8 +511,7 @@ class BatchedInferencePipeline:
             condition_on_previous_text=False,
             clip_timestamps=clip_timestamps,
             prompt_reset_on_temperature=0.5,
-            multilingual=False,
-            output_language=None,
+            multilingual=multilingual,
             without_timestamps=without_timestamps,
             max_initial_timestamp=0.0,
         )
@@ -721,7 +733,6 @@ class WhisperModel:
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
         multilingual: bool = False,
-        output_language: Optional[str] = None,
         vad_filter: bool = False,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
@@ -781,12 +792,7 @@ class WhisperModel:
                 with the next word
             append_punctuations: If word_timestamps is True, merge these punctuation symbols
                 with the previous word
-            multilingual: If True, perform transcription on multilingual videos
-                and return the transcript based
-                on the 'output_language' flag.
-            output_language: Valid only if multilingual is set to True.
-                Specifies the string representing the output language. One of
-                'en' (English) or 'hybrid' (code-switched transcription).
+            multilingual: Perform language detection on every segment.
             vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
                 without speech. This step is using the Silero VAD model
                 https://github.com/snakers4/silero-vad.
@@ -817,6 +823,13 @@ class WhisperModel:

         sampling_rate = self.feature_extractor.sampling_rate

+        if multilingual and not self.model.is_multilingual:
+            self.logger.warning(
+                "The current model is English-only but the multilingual parameter is set to "
+                "True; setting to False instead."
+            )
+            multilingual = False
+
         if not isinstance(audio, np.ndarray):
             audio = decode_audio(audio, sampling_rate=sampling_rate)

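Both pipelines get the same guard: asking for multilingual mode on an English-only model logs the warning above and falls back rather than failing. A quick usage sketch (the file name is illustrative):

    from faster_whisper import WhisperModel

    model = WhisperModel("tiny.en")  # English-only checkpoint

    # A warning is logged and transcription proceeds with multilingual=False
    # instead of raising an error.
    segments, info = model.transcribe("audio.wav", multilingual=True)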
@@ -863,13 +876,6 @@ class WhisperModel:
         encoder_output = None
         all_language_probs = None

-        # setting output_language for multilingual videos
-        if multilingual:
-            if output_language is None:
-                output_language = "en"
-            elif output_language not in ["en", "hybrid"]:
-                raise ValueError("Output language needs to be one of 'en'/'hybrid'.")
-
         # detecting the language if not provided
         if language is None:
             if not self.model.is_multilingual:
@@ -949,7 +955,6 @@ class WhisperModel:
             prepend_punctuations=prepend_punctuations,
             append_punctuations=append_punctuations,
             multilingual=multilingual,
-            output_language=output_language,
             max_new_tokens=max_new_tokens,
             clip_timestamps=clip_timestamps,
             hallucination_silence_threshold=hallucination_silence_threshold,
@@ -1139,27 +1144,17 @@ class WhisperModel:

             previous_tokens = all_tokens[prompt_reset_since:]

-            if encoder_output is None:
+            if seek > 0 or encoder_output is None:
                 encoder_output = self.encode(segment)

-            # Perform language detection at every segment to update task based on output language,
-            # if the language is english, task is transcribe,
-            # else the task is translate to english (default)
-            # or transcribe if 'output_language' is 'hybrid'.
             if options.multilingual:
                 results = self.model.detect_language(encoder_output)
                 language_token, language_probability = results[0][0]
                 language = language_token[2:-2]
-                if options.output_language == "en" and language != "en":
-                    task = "translate"
-                else:
-                    task = "transcribe"
-
-                # Update tokenizer based on task and language
-                tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>")
+
                 tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
                 tokenizer.language_code = language
-            # Update prompt based on task and language
+
             prompt = self.get_prompt(
                 tokenizer,
                 previous_tokens,
@@ -1168,9 +1163,6 @@ class WhisperModel:
                 hotwords=options.hotwords,
             )

-            if seek > 0 or encoder_output is None:
-                encoder_output = self.encode(segment)
-
             (
                 result,
                 avg_logprob,
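These two hunks are the "correct `encoder_output`" fix from the commit message. Previously the top of the decode loop re-encoded only when `encoder_output` was `None`, so from the second window onward per-segment language detection ran against the previous window's encoder output; generation itself was unaffected only because a second, `seek > 0`-guarded encode sat further down. Moving the `seek > 0` condition up and deleting the later block lets detection and generation share one correctly scoped encode. A toy re-creation of the guard's effect, with an integer standing in for `self.encode(segment)`:

    def windows_seen(num_windows, use_new_guard):
        encoder_output = 0  # produced up front by initial language detection on window 0
        seen = []
        for seek in range(num_windows):
            if use_new_guard:
                needs_encode = seek > 0 or encoder_output is None  # new guard
            else:
                needs_encode = encoder_output is None  # old guard
            if needs_encode:
                encoder_output = seek  # stand-in for self.encode(segment)
            seen.append(encoder_output)
        return seen

    assert windows_seen(3, use_new_guard=False) == [0, 0, 0]  # windows 1-2 reused stale output
    assert windows_seen(3, use_new_guard=True) == [0, 1, 2]   # every window freshly encoded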
BIN tests/data/multilingual.mp3 (new file; binary file not shown)
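Per the commit message, the new fixture is a single recording with an English passage followed by a German one. Loading it for local experimentation mirrors what the test below does (`decode_audio` is the real faster-whisper helper; the path assumes a repository checkout):

    import os

    from faster_whisper import decode_audio

    audio = decode_audio(os.path.join("tests", "data", "multilingual.mp3"))
    print(audio.shape)  # mono float32 samples at the default 16 kHz rate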
tests/test_transcribe.py

@@ -158,6 +158,63 @@ def test_stereo_diarization(data_dir):
     assert transcription == "The horizon seems extremely distant."


+def test_multilingual_transcription(data_dir):
+    model = WhisperModel("tiny")
+    pipeline = BatchedInferencePipeline(model)
+
+    audio_path = os.path.join(data_dir, "multilingual.mp3")
+    audio = decode_audio(audio_path)
+
+    segments, info = model.transcribe(
+        audio,
+        multilingual=True,
+        without_timestamps=True,
+        condition_on_previous_text=False,
+    )
+    segments = list(segments)
+
+    assert (
+        segments[0].text
+        == " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
+        " software and associated documentation files to deal in the software without restriction,"
+        " including without limitation the rights to use, copy, modify, merge, publish, distribute"
+        ", sublicence, and or cell copies of the software, and to permit persons to whom the "
+        "software is furnished to do so, subject to the following conditions. The above copyright"
+        " notice and this permission notice, shall be included in all copies or substantial "
+        "portions of the software."
+    )
+
+    assert (
+        segments[1].text
+        == " Jedem, der dieses Software und die dazu gehöregen Dokumentationsdatein erhält, wird "
+        "hiermit unengeltlich die Genehmigung erteilt, wird der Software und eingeschränkt zu "
+        "verfahren. Dies umfasst insbesondere das Recht, die Software zu verwenden, zu "
+        "vervielfältigen, zu modifizieren, zu Samenzofügen, zu veröffentlichen, zu verteilen, "
+        "unterzulizenzieren und oder kopieren der Software zu verkaufen und diese Rechte "
+        "unterfolgen den Bedingungen anderen zu übertragen."
+    )
+
+    segments, info = pipeline.transcribe(audio, multilingual=True)
+    segments = list(segments)
+
+    assert (
+        segments[0].text
+        == " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
+        " software and associated documentation files to deal in the software without restriction,"
+        " including without limitation the rights to use, copy, modify, merge, publish, distribute"
+        ", sublicence, and or cell copies of the software, and to permit persons to whom the "
+        "software is furnished to do so, subject to the following conditions. The above copyright"
+        " notice and this permission notice, shall be included in all copies or substantial "
+        "portions of the software."
+    )
+    assert (
+        "Dokumentationsdatein erhält, wird hiermit unengeltlich die Genehmigung erteilt,"
+        " wird der Software und eingeschränkt zu verfahren. Dies umfasst insbesondere das Recht,"
+        " die Software zu verwenden, zu vervielfältigen, zu modifizieren"
+        in segments[1].text
+    )
+
+
 def test_suppressed_tokens_minus_1():
     model = WhisperModel("tiny.en")

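To run only the new test from a repository checkout (assuming the project's usual pytest setup):

    pytest tests/test_transcribe.py -k test_multilingual_transcription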