refactor multilingual option (#1148)

* Added a test for the `multilingual` option with English-German audio
* Removed the `output_language` argument as it is redundant; the same behavior is available via `task="translate"`
* Used the correct `encoder_output` for language detection in sequential transcription
* Enabled `multilingual` functionality for batched inference
Author: Mahmoud Ashraf
Date: 2024-11-19 23:14:59 +02:00 (committed by GitHub)
Parent: be9fb36ed3
Commit: bcd8ce0fc7
3 changed files with 88 additions and 39 deletions
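
For orientation, a minimal usage sketch of the refactored option follows. The model size and audio path are placeholders, and the calls mirror the new test added at the end of this commit:

```python
from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("tiny")

# Sequential API: the language is re-detected for every window and each
# window is transcribed in whichever language was detected.
segments, info = model.transcribe("multilingual.mp3", multilingual=True)

# Batched API: the same flag is now supported here as well.
pipeline = BatchedInferencePipeline(model)
segments, info = pipeline.transcribe("multilingual.mp3", multilingual=True)

# The removed output_language="en" behavior maps onto the existing task flag:
segments, info = model.transcribe("multilingual.mp3", task="translate")
```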


@@ -93,7 +93,6 @@ class TranscriptionOptions:
prepend_punctuations: str prepend_punctuations: str
append_punctuations: str append_punctuations: str
multilingual: bool multilingual: bool
output_language: Optional[str]
max_new_tokens: Optional[int] max_new_tokens: Optional[int]
clip_timestamps: Union[str, List[float]] clip_timestamps: Union[str, List[float]]
hallucination_silence_threshold: Optional[float] hallucination_silence_threshold: Optional[float]
@@ -210,10 +209,21 @@ class BatchedInferencePipeline:
) )
encoder_output = self.model.encode(features) encoder_output = self.model.encode(features)
prompts = [prompt.copy() for _ in range(batch_size)]
if options.multilingual:
language_tokens = [
tokenizer.tokenizer.token_to_id(segment_langs[0][0])
for segment_langs in self.model.model.detect_language(encoder_output)
]
language_token_index = prompt.index(tokenizer.language)
for i, language_token in enumerate(language_tokens):
prompts[i][language_token_index] = language_token
results = self.model.model.generate( results = self.model.model.generate(
encoder_output, encoder_output,
[prompt] * batch_size, prompts,
beam_size=options.beam_size, beam_size=options.beam_size,
patience=options.patience, patience=options.patience,
length_penalty=options.length_penalty, length_penalty=options.length_penalty,
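
The batched change above works by rewriting the language token inside each per-item prompt before calling `generate`. A rough standalone illustration of that swap, with hypothetical integer IDs standing in for the real Whisper special tokens:

```python
# Hypothetical token IDs, for illustration only.
SOT, LANG_EN, LANG_DE, TRANSCRIBE = 50258, 50259, 50261, 50359

# Shared base prompt: <|startoftranscript|><|en|><|transcribe|>
prompt = [SOT, LANG_EN, TRANSCRIBE]
prompts = [prompt.copy() for _ in range(2)]  # one copy per batch item

detected = [LANG_EN, LANG_DE]  # pretend per-item detection results
idx = prompt.index(LANG_EN)    # position of the language token

for i, token in enumerate(detected):
    prompts[i][idx] = token    # each item decodes in its own language

assert prompts == [[SOT, LANG_EN, TRANSCRIBE], [SOT, LANG_DE, TRANSCRIBE]]
```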
@@ -279,7 +289,6 @@ class BatchedInferencePipeline:
prepend_punctuations: str = "\"'“¿([{-", prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
multilingual: bool = False, multilingual: bool = False,
output_language: Optional[str] = None,
vad_filter: bool = True, vad_filter: bool = True,
vad_parameters: Optional[Union[dict, VadOptions]] = None, vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None, max_new_tokens: Optional[int] = None,
@@ -322,6 +331,7 @@ class BatchedInferencePipeline:
with the next word with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word with the previous word
multilingual: Perform language detection on every segment.
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad. https://github.com/snakers4/silero-vad.
@@ -360,10 +370,6 @@ class BatchedInferencePipeline:
Arg has effect only if condition_on_previous_text is True. Set at 0.5 Arg has effect only if condition_on_previous_text is True. Set at 0.5
prefix: Optional text to provide as a prefix at the beginning of each window. prefix: Optional text to provide as a prefix at the beginning of each window.
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0. max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
multilingual: If True, perform transcription on multilingual videos. Set as False.
output_language: Valid only if multilingual is set to True.
Specifies the string representing the output language. One of
'en' (English) or 'hybrid' (code-switched transcription). set as None.
hallucination_silence_threshold: Optional[float] hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected. set as None. (in seconds) when a possible hallucination is detected. set as None.
@@ -376,6 +382,13 @@ class BatchedInferencePipeline:
sampling_rate = self.model.feature_extractor.sampling_rate sampling_rate = self.model.feature_extractor.sampling_rate
if multilingual and not self.model.model.is_multilingual:
self.model.logger.warning(
"The current model is English-only but the multilingual parameter is set to"
"True; setting to False instead."
)
multilingual = False
if not isinstance(audio, np.ndarray): if not isinstance(audio, np.ndarray):
audio = decode_audio(audio, sampling_rate=sampling_rate) audio = decode_audio(audio, sampling_rate=sampling_rate)
duration = audio.shape[0] / sampling_rate duration = audio.shape[0] / sampling_rate
@@ -498,8 +511,7 @@ class BatchedInferencePipeline:
condition_on_previous_text=False, condition_on_previous_text=False,
clip_timestamps=clip_timestamps, clip_timestamps=clip_timestamps,
prompt_reset_on_temperature=0.5, prompt_reset_on_temperature=0.5,
multilingual=False, multilingual=multilingual,
output_language=None,
without_timestamps=without_timestamps, without_timestamps=without_timestamps,
max_initial_timestamp=0.0, max_initial_timestamp=0.0,
) )
@@ -721,7 +733,6 @@ class WhisperModel:
prepend_punctuations: str = "\"'“¿([{-", prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
multilingual: bool = False, multilingual: bool = False,
output_language: Optional[str] = None,
vad_filter: bool = False, vad_filter: bool = False,
vad_parameters: Optional[Union[dict, VadOptions]] = None, vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None, max_new_tokens: Optional[int] = None,
@@ -781,12 +792,7 @@ class WhisperModel:
with the next word with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word with the previous word
multilingual: If True, perform transcription on multilingual videos multilingual: Perform language detection on every segment.
and return the transcript based
on the 'output_language' flag.
output_language: Valid only if multilingual is set to True.
Specifies the string representing the output language. One of
'en' (English) or 'hybrid' (code-switched transcription).
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad. https://github.com/snakers4/silero-vad.
@@ -817,6 +823,13 @@ class WhisperModel:
sampling_rate = self.feature_extractor.sampling_rate sampling_rate = self.feature_extractor.sampling_rate
if multilingual and not self.model.is_multilingual:
self.logger.warning(
"The current model is English-only but the multilingual parameter is set to"
"True; setting to False instead."
)
multilingual = False
if not isinstance(audio, np.ndarray): if not isinstance(audio, np.ndarray):
audio = decode_audio(audio, sampling_rate=sampling_rate) audio = decode_audio(audio, sampling_rate=sampling_rate)
@@ -863,13 +876,6 @@ class WhisperModel:
encoder_output = None encoder_output = None
all_language_probs = None all_language_probs = None
# setting output_language for multilingual videos
if multilingual:
if output_language is None:
output_language = "en"
elif output_language not in ["en", "hybrid"]:
raise ValueError("Output language needs to be one of 'en'/'hybrid'.")
# detecting the language if not provided # detecting the language if not provided
if language is None: if language is None:
if not self.model.is_multilingual: if not self.model.is_multilingual:
@@ -949,7 +955,6 @@ class WhisperModel:
prepend_punctuations=prepend_punctuations, prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations, append_punctuations=append_punctuations,
multilingual=multilingual, multilingual=multilingual,
output_language=output_language,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
clip_timestamps=clip_timestamps, clip_timestamps=clip_timestamps,
hallucination_silence_threshold=hallucination_silence_threshold, hallucination_silence_threshold=hallucination_silence_threshold,
@@ -1139,27 +1144,17 @@ class WhisperModel:
previous_tokens = all_tokens[prompt_reset_since:] previous_tokens = all_tokens[prompt_reset_since:]
if encoder_output is None: if seek > 0 or encoder_output is None:
encoder_output = self.encode(segment) encoder_output = self.encode(segment)
# Perform language detection at every segment to update task based on output language,
# if the language is english, task is transcribe,
# else the task is translate to english (default)
# or transcribe if 'output_language' is 'hybrid'.
if options.multilingual: if options.multilingual:
results = self.model.detect_language(encoder_output) results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0] language_token, language_probability = results[0][0]
language = language_token[2:-2] language = language_token[2:-2]
if options.output_language == "en" and language != "en":
task = "translate"
else:
task = "transcribe"
# Update tokenizer based on task and language
tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>")
tokenizer.language = tokenizer.tokenizer.token_to_id(language_token) tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
tokenizer.language_code = language tokenizer.language_code = language
# Update prompt based on task and language
prompt = self.get_prompt( prompt = self.get_prompt(
tokenizer, tokenizer,
previous_tokens, previous_tokens,
@@ -1168,9 +1163,6 @@ class WhisperModel:
hotwords=options.hotwords, hotwords=options.hotwords,
) )
if seek > 0 or encoder_output is None:
encoder_output = self.encode(segment)
( (
result, result,
avg_logprob, avg_logprob,
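
Two fixes land in the sequential hunks above: re-encoding now happens before language detection and also whenever `seek > 0`, so detection always runs on the encoder output of the current window rather than a stale one, and the `output_language`-driven task switching is dropped, leaving detection to update only the tokenizer's language. Condensed, the per-window flow now reads roughly as follows (simplified from the hunk, not a verbatim excerpt):

```python
# Re-encode once we have advanced past the first window, so that
# per-segment language detection never sees a stale encoder state.
if seek > 0 or encoder_output is None:
    encoder_output = self.encode(segment)

if options.multilingual:
    results = self.model.detect_language(encoder_output)
    language_token, language_probability = results[0][0]
    language = language_token[2:-2]  # "<|de|>" -> "de"
    tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
    tokenizer.language_code = language  # the next get_prompt call picks this up
```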

tests/data/multilingual.mp3 (new binary file)


@@ -158,6 +158,63 @@ def test_stereo_diarization(data_dir):
assert transcription == "The horizon seems extremely distant." assert transcription == "The horizon seems extremely distant."
def test_multilingual_transcription(data_dir):
model = WhisperModel("tiny")
pipeline = BatchedInferencePipeline(model)
audio_path = os.path.join(data_dir, "multilingual.mp3")
audio = decode_audio(audio_path)
segments, info = model.transcribe(
audio,
multilingual=True,
without_timestamps=True,
condition_on_previous_text=False,
)
segments = list(segments)
assert (
segments[0].text
== " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
" software and associated documentation files to deal in the software without restriction,"
" including without limitation the rights to use, copy, modify, merge, publish, distribute"
", sublicence, and or cell copies of the software, and to permit persons to whom the "
"software is furnished to do so, subject to the following conditions. The above copyright"
" notice and this permission notice, shall be included in all copies or substantial "
"portions of the software."
)
assert (
segments[1].text
== " Jedem, der dieses Software und die dazu gehöregen Dokumentationsdatein erhält, wird "
"hiermit unengeltlich die Genehmigung erteilt, wird der Software und eingeschränkt zu "
"verfahren. Dies umfasst insbesondere das Recht, die Software zu verwenden, zu "
"vervielfältigen, zu modifizieren, zu Samenzofügen, zu veröffentlichen, zu verteilen, "
"unterzulizenzieren und oder kopieren der Software zu verkaufen und diese Rechte "
"unterfolgen den Bedingungen anderen zu übertragen."
)
segments, info = pipeline.transcribe(audio, multilingual=True)
segments = list(segments)
assert (
segments[0].text
== " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
" software and associated documentation files to deal in the software without restriction,"
" including without limitation the rights to use, copy, modify, merge, publish, distribute"
", sublicence, and or cell copies of the software, and to permit persons to whom the "
"software is furnished to do so, subject to the following conditions. The above copyright"
" notice and this permission notice, shall be included in all copies or substantial "
"portions of the software."
)
assert (
"Dokumentationsdatein erhält, wird hiermit unengeltlich die Genehmigung erteilt,"
" wird der Software und eingeschränkt zu verfahren. Dies umfasst insbesondere das Recht,"
" die Software zu verwenden, zu vervielfältigen, zu modifizieren"
in segments[1].text
)
def test_suppressed_tokens_minus_1(): def test_suppressed_tokens_minus_1():
model = WhisperModel("tiny.en") model = WhisperModel("tiny.en")
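
Assuming a standard pytest setup and the bundled `tests/data` fixtures, the new case can be run on its own with `pytest -k test_multilingual_transcription`.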