refactor multilingual option (#1148)

* Added a test for the `multilingual` option with English-German audio
* Removed the `output_language` argument; it was redundant, since the same behavior is available through `task="translate"` (see the usage sketch below)
* Used the correct `encoder_output` for language detection in sequential transcription
* Enabled the `multilingual` functionality for batched inference
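
For reference, a minimal usage sketch of the API after this change (the model size and audio file name are illustrative, not from this commit):

```python
from faster_whisper import WhisperModel

model = WhisperModel("large-v3")

# Code-switched transcription: the language is re-detected on every segment
# and each segment is transcribed in its own language.
segments, info = model.transcribe("english_german.wav", multilingual=True)

# The old `output_language="en"` behavior is now simply the translate task,
# which renders everything in English.
segments, info = model.transcribe("english_german.wav", task="translate")
```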
Mahmoud Ashraf committed 2024-11-19 23:14:59 +02:00 (via GitHub)
parent be9fb36ed3 · commit bcd8ce0fc7
3 changed files with 88 additions and 39 deletions

faster_whisper/transcribe.py

@@ -93,7 +93,6 @@ class TranscriptionOptions:
     prepend_punctuations: str
     append_punctuations: str
     multilingual: bool
-    output_language: Optional[str]
     max_new_tokens: Optional[int]
     clip_timestamps: Union[str, List[float]]
     hallucination_silence_threshold: Optional[float]
@@ -210,10 +209,21 @@ class BatchedInferencePipeline:
        )

        encoder_output = self.model.encode(features)
+        prompts = [prompt.copy() for _ in range(batch_size)]
+
+        if options.multilingual:
+            language_tokens = [
+                tokenizer.tokenizer.token_to_id(segment_langs[0][0])
+                for segment_langs in self.model.model.detect_language(encoder_output)
+            ]
+            language_token_index = prompt.index(tokenizer.language)
+            for i, language_token in enumerate(language_tokens):
+                prompts[i][language_token_index] = language_token
+
        results = self.model.model.generate(
            encoder_output,
-            [prompt] * batch_size,
+            prompts,
            beam_size=options.beam_size,
            patience=options.patience,
            length_penalty=options.length_penalty,
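
To make the batched branch above concrete, here is a self-contained toy sketch of the prompt-patching logic; the token ids and detection results below are hypothetical stand-ins for what `tokenizer.tokenizer.token_to_id` and CTranslate2's `detect_language` would return:

```python
# Toy stand-ins: detect_language returns, per batch item, (token, prob) pairs
# sorted by descending probability; the ids are illustrative, not real vocab.
detected = [
    [("<|de|>", 0.93), ("<|en|>", 0.04)],  # item 0 sounds German
    [("<|en|>", 0.88), ("<|de|>", 0.07)],  # item 1 sounds English
]
token_to_id = {"<|en|>": 50259, "<|de|>": 50261}

# A Whisper prompt starts <|startoftranscript|><|lang|><|task|>...
prompt = [50258, 50259, 50360]
prompts = [prompt.copy() for _ in range(len(detected))]

# Overwrite the single language slot with each item's top-ranked language.
language_token_index = prompt.index(50259)
for i, segment_langs in enumerate(detected):
    prompts[i][language_token_index] = token_to_id[segment_langs[0][0]]

print(prompts)  # [[50258, 50261, 50360], [50258, 50259, 50360]]
```

This is why each batch item can end up with a different language token while sharing the rest of the prompt.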
@@ -279,7 +289,6 @@ class BatchedInferencePipeline:
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,!?::”)]}、",
         multilingual: bool = False,
-        output_language: Optional[str] = None,
         vad_filter: bool = True,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
@@ -322,6 +331,7 @@ class BatchedInferencePipeline:
               with the next word
             append_punctuations: If word_timestamps is True, merge these punctuation symbols
               with the previous word
+            multilingual: Perform language detection on every segment.
             vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
               without speech. This step is using the Silero VAD model
               https://github.com/snakers4/silero-vad.
@@ -360,10 +370,6 @@ class BatchedInferencePipeline:
               Arg has effect only if condition_on_previous_text is True. Set at 0.5
             prefix: Optional text to provide as a prefix at the beginning of each window.
             max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
-            multilingual: If True, perform transcription on multilingual videos. Set as False.
-            output_language: Valid only if multilingual is set to True.
-              Specifies the string representing the output language. One of
-              'en' (English) or 'hybrid' (code-switched transcription). set as None.
             hallucination_silence_threshold: Optional[float]
               When word_timestamps is True, skip silent periods longer than this threshold
               (in seconds) when a possible hallucination is detected. set as None.
@@ -376,6 +382,13 @@ class BatchedInferencePipeline:
        sampling_rate = self.model.feature_extractor.sampling_rate

+        if multilingual and not self.model.model.is_multilingual:
+            self.model.logger.warning(
+                "The current model is English-only but the multilingual parameter is set to "
+                "True; setting to False instead."
+            )
+            multilingual = False
+
        if not isinstance(audio, np.ndarray):
            audio = decode_audio(audio, sampling_rate=sampling_rate)

        duration = audio.shape[0] / sampling_rate
@@ -498,8 +511,7 @@ class BatchedInferencePipeline:
            condition_on_previous_text=False,
            clip_timestamps=clip_timestamps,
            prompt_reset_on_temperature=0.5,
-            multilingual=False,
-            output_language=None,
+            multilingual=multilingual,
            without_timestamps=without_timestamps,
            max_initial_timestamp=0.0,
        )
@@ -721,7 +733,6 @@ class WhisperModel:
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,!?::”)]}、",
         multilingual: bool = False,
-        output_language: Optional[str] = None,
         vad_filter: bool = False,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
@@ -781,12 +792,7 @@ class WhisperModel:
               with the next word
             append_punctuations: If word_timestamps is True, merge these punctuation symbols
               with the previous word
-            multilingual: If True, perform transcription on multilingual videos
-              and return the transcript based
-              on the 'output_language' flag.
-            output_language: Valid only if multilingual is set to True.
-              Specifies the string representing the output language. One of
-              'en' (English) or 'hybrid' (code-switched transcription).
+            multilingual: Perform language detection on every segment.
             vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
               without speech. This step is using the Silero VAD model
               https://github.com/snakers4/silero-vad.
@@ -817,6 +823,13 @@ class WhisperModel:
        sampling_rate = self.feature_extractor.sampling_rate

+        if multilingual and not self.model.is_multilingual:
+            self.logger.warning(
+                "The current model is English-only but the multilingual parameter is set to "
+                "True; setting to False instead."
+            )
+            multilingual = False
+
        if not isinstance(audio, np.ndarray):
            audio = decode_audio(audio, sampling_rate=sampling_rate)
@@ -863,13 +876,6 @@ class WhisperModel:
        encoder_output = None
        all_language_probs = None

-        # setting output_language for multilingual videos
-        if multilingual:
-            if output_language is None:
-                output_language = "en"
-            elif output_language not in ["en", "hybrid"]:
-                raise ValueError("Output language needs to be one of 'en'/'hybrid'.")
-
        # detecting the language if not provided
        if language is None:
            if not self.model.is_multilingual:
@@ -949,7 +955,6 @@ class WhisperModel:
            prepend_punctuations=prepend_punctuations,
            append_punctuations=append_punctuations,
            multilingual=multilingual,
-            output_language=output_language,
            max_new_tokens=max_new_tokens,
            clip_timestamps=clip_timestamps,
            hallucination_silence_threshold=hallucination_silence_threshold,
@@ -1139,27 +1144,17 @@ class WhisperModel:
            previous_tokens = all_tokens[prompt_reset_since:]

-            if encoder_output is None:
+            if seek > 0 or encoder_output is None:
                encoder_output = self.encode(segment)

-            # Perform language detection at every segment to update task based on output language,
-            # if the language is english, task is transcribe,
-            # else the task is translate to english (default)
-            # or transcribe if 'output_language' is 'hybrid'.
            if options.multilingual:
                results = self.model.detect_language(encoder_output)
                language_token, language_probability = results[0][0]
                language = language_token[2:-2]
-                if options.output_language == "en" and language != "en":
-                    task = "translate"
-                else:
-                    task = "transcribe"

-                # Update tokenizer based on task and language
-                tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>")
                tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
                tokenizer.language_code = language

-            # Update prompt based on task and language
            prompt = self.get_prompt(
                tokenizer,
                previous_tokens,
@@ -1168,9 +1163,6 @@
                hotwords=options.hotwords,
            )

-            if seek > 0 or encoder_output is None:
-                encoder_output = self.encode(segment)
-
            (
                result,
                avg_logprob,