Mirror of https://github.com/SYSTRAN/faster-whisper.git (synced 2026-01-08 13:14:00 -05:00)
Cleanup of BatchedInferencePipeline (#1135)
@@ -57,7 +57,7 @@ class Segment:
     compression_ratio: float
     no_speech_prob: float
     words: Optional[List[Word]]
-    temperature: Optional[float] = 1.0
+    temperature: Optional[float]

     def _asdict(self):
         warn(
@@ -68,7 +68,6 @@ class Segment:
         return asdict(self)


-# Added additional parameters for multilingual videos and fixes below
 @dataclass
 class TranscriptionOptions:
     beam_size: int
@@ -112,34 +111,17 @@ class TranscriptionInfo:
     vad_options: VadOptions


-# The code below is originally from HF pipeline and is used in whisper-x
-# (https://github.com/m-bain/whisperX) and adapted for faster_whisper
-
-
 class BatchedInferencePipeline:
-    """
-    Huggingface Pipeline wrapper for WhisperModel.
-    Copyright (c) 2022, Max Bain
-    All rights reserved.
-    Modified by Mobius Labs GmbH
-    """
-
     def __init__(
         self,
         model,
-        options: Optional[TranscriptionOptions] = None,
-        tokenizer=None,
-        language: Optional[str] = None,
     ):
         self.model: WhisperModel = model
-        self.tokenizer = tokenizer
-        self.options = options
-        self.preset_language = language
         self.last_speech_timestamp = 0.0

-    def forward(self, features, chunks_metadata, **forward_params):
-        encoder_output, outputs = self.model.generate_segment_batched(
-            features, self.tokenizer, forward_params
+    def forward(self, features, tokenizer, chunks_metadata, options):
+        encoder_output, outputs = self.generate_segment_batched(
+            features, tokenizer, options
         )

         segmented_outputs = []
@@ -153,7 +135,7 @@ class BatchedInferencePipeline:
                 seek,
                 single_timestamp_ending,
             ) = self.model._split_segments_by_timestamps(
-                tokenizer=self.tokenizer,
+                tokenizer=tokenizer,
                 tokens=output["tokens"],
                 time_offset=chunk_metadata["start_time"],
                 segment_size=segment_size,
@@ -163,14 +145,14 @@ class BatchedInferencePipeline:
             segmented_outputs.append(
                 [
                     dict(
-                        text=self.tokenizer.decode(subsegment["tokens"]),
+                        text=tokenizer.decode(subsegment["tokens"]),
                         avg_logprob=output["avg_logprob"],
                         no_speech_prob=output["no_speech_prob"],
                         tokens=subsegment["tokens"],
                         start=subsegment["start"],
                         end=subsegment["end"],
                         compression_ratio=get_compression_ratio(
-                            self.tokenizer.decode(subsegment["tokens"])
+                            tokenizer.decode(subsegment["tokens"])
                         ),
                         seek=int(
                             chunk_metadata["start_time"] * self.model.frames_per_second
@@ -179,19 +161,88 @@ class BatchedInferencePipeline:
                     for subsegment in subsegments
                 ]
             )
-        if forward_params["word_timestamps"]:
+        if options.word_timestamps:
             self.last_speech_timestamp = self.model.add_word_timestamps(
                 segmented_outputs,
-                self.tokenizer,
+                tokenizer,
                 encoder_output,
                 segment_sizes,
-                forward_params["prepend_punctuations"],
-                forward_params["append_punctuations"],
+                options.prepend_punctuations,
+                options.append_punctuations,
                 self.last_speech_timestamp,
             )

         return segmented_outputs

+    def generate_segment_batched(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+    ):
+        batch_size = features.shape[0]
+
+        prompt = self.model.get_prompt(
+            tokenizer,
+            previous_tokens=(
+                tokenizer.encode(options.initial_prompt)
+                if options.initial_prompt is not None
+                else []
+            ),
+            without_timestamps=options.without_timestamps,
+            hotwords=options.hotwords,
+        )
+
+        if options.max_new_tokens is not None:
+            max_length = len(prompt) + options.max_new_tokens
+        else:
+            max_length = self.model.max_length
+
+        if max_length > self.model.max_length:
+            raise ValueError(
+                f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
+                f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
+                f"and `max_new_tokens` is: {max_length}. This exceeds the "
+                f"`max_length` of the Whisper model: {self.model.max_length}. "
+                "You should either reduce the length of your prompt, or "
+                "reduce the value of `max_new_tokens`, "
+                f"so that their combined length is less that {self.model.max_length}."
+            )
+
+        encoder_output = self.model.encode(features)
+
+        results = self.model.model.generate(
+            encoder_output,
+            [prompt] * batch_size,
+            beam_size=options.beam_size,
+            patience=options.patience,
+            length_penalty=options.length_penalty,
+            max_length=max_length,
+            suppress_blank=options.suppress_blank,
+            suppress_tokens=options.suppress_tokens,
+            return_scores=True,
+            return_no_speech_prob=True,
+            sampling_temperature=options.temperatures[0],
+            repetition_penalty=options.repetition_penalty,
+            no_repeat_ngram_size=options.no_repeat_ngram_size,
+        )
+
+        output = []
+        for result in results:
+            # return scores
+            seq_len = len(result.sequences_ids[0])
+            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
+
+            output.append(
+                dict(
+                    avg_logprob=cum_logprob / (seq_len + 1),
+                    no_speech_prob=result.no_speech_prob,
+                    tokens=result.sequences_ids[0],
+                )
+            )
+
+        return encoder_output, output
+
     def transcribe(
         self,
         audio: Union[str, BinaryIO, np.ndarray],
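Note on the scoring arithmetic in the new generate_segment_batched: the code treats result.scores[0] as CTranslate2's length-normalized score of the best hypothesis, multiplies it by seq_len**length_penalty to recover the cumulative log-probability, and divides by seq_len + 1 to get the per-token average stored as avg_logprob. A minimal sketch of that arithmetic with made-up numbers (not part of the diff):

    # Worked example of the avg_logprob computation above; the token ids and
    # normalized score are invented for illustration only.
    length_penalty = 1.0
    sequences_ids = [50364, 2425, 11, 1002, 13]  # hypothetical token ids
    normalized_score = -0.35                     # stand-in for result.scores[0]

    seq_len = len(sequences_ids)                                # 5
    cum_logprob = normalized_score * (seq_len**length_penalty)  # -1.75
    avg_logprob = cum_logprob / (seq_len + 1)                   # ~ -0.292
    print(avg_logprob)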
@@ -216,20 +267,26 @@ class BatchedInferencePipeline:
         log_prob_threshold: Optional[float] = -1.0,
         log_prob_low_threshold: Optional[float] = None,
         no_speech_threshold: Optional[float] = 0.6,
+        condition_on_previous_text: bool = True,
+        prompt_reset_on_temperature: float = 0.5,
         initial_prompt: Optional[Union[str, Iterable[int]]] = None,
         prefix: Optional[str] = None,
         suppress_blank: bool = True,
         suppress_tokens: Optional[List[int]] = [-1],
         without_timestamps: bool = True,
+        max_initial_timestamp: float = 1.0,
         word_timestamps: bool = False,
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
+        multilingual: bool = False,
+        output_language: Optional[str] = None,
         vad_filter: bool = True,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
         chunk_length: Optional[int] = None,
         clip_timestamps: Optional[List[dict]] = None,
-        batch_size: int = 16,
+        hallucination_silence_threshold: Optional[float] = None,
+        batch_size: int = 8,
         hotwords: Optional[str] = None,
         language_detection_threshold: Optional[float] = 0.5,
         language_detection_segments: int = 1,
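For orientation, a minimal usage sketch of the pipeline after this cleanup. The model name and audio path are placeholders, and the import assumes BatchedInferencePipeline is exposed from the package root; only the model-only constructor and the batch_size default of 8 are taken from the diff itself:

    # Hypothetical usage sketch; "large-v3" and "audio.wav" are placeholders.
    from faster_whisper import BatchedInferencePipeline, WhisperModel

    model = WhisperModel("large-v3", device="cuda", compute_type="float16")
    pipeline = BatchedInferencePipeline(model)  # tokenizer/options are no longer passed here

    # transcribe() builds the tokenizer and TranscriptionOptions per call.
    segments, info = pipeline.transcribe("audio.wav", batch_size=8, word_timestamps=True)
    for segment in segments:
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")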
@@ -250,22 +307,10 @@ class BatchedInferencePipeline:
           repetition_penalty: Penalty applied to the score of previously generated tokens
             (set > 1 to penalize).
           no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
-          temperature: Temperature for sampling. It can be a tuple of temperatures,
-            which will be successively used upon failures according to either
-            `compression_ratio_threshold` or `log_prob_threshold`.
-          compression_ratio_threshold: If the gzip compression ratio is above this value,
-            treat as failed.
-          log_prob_threshold: If the average log probability over sampled tokens is
-            below this value, treat as failed.
-          log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
-            whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
-            This value should be less than log_prob_threshold.
-          no_speech_threshold: If the no_speech probability is higher than this value AND
-            the average log probability over sampled tokens is below `log_prob_threshold`,
-            consider the segment as silent.
+          temperature: Temperature for sampling. If a list or tuple is passed,
+            only the first value is used.
           initial_prompt: Optional text string or iterable of token ids to provide as a
-            prompt for the first window.
-          prefix: Optional text to provide as a prefix for the first window.
+            prompt for the each window.
           suppress_blank: Suppress blank outputs at the beginning of the sampling.
           suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
             of symbols as defined in `tokenizer.non_speech_tokens()`.
@@ -296,29 +341,32 @@ class BatchedInferencePipeline:
             higher than this value, the language is detected.
           language_detection_segments: Number of segments to consider for the language detection.

-        Static params: (Fixed for batched version)
-          max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
-          multilingual: If True, perform transcription on multilingual videos. Set as False.
-          output_language: Valid only if multilingual is set to True.
-            Specifies the string representing the output language. One of
-            'en' (English) or 'hybrid' (code-switched transcription). set as None.
+          Unused Arguments
+          compression_ratio_threshold: If the gzip compression ratio is above this value,
+            treat as failed.
+          log_prob_threshold: If the average log probability over sampled tokens is
+            below this value, treat as failed.
+          log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
+            whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
+            This value should be less than log_prob_threshold.
+          no_speech_threshold: If the no_speech probability is higher than this value AND
+            the average log probability over sampled tokens is below `log_prob_threshold`,
+            consider the segment as silent.
           condition_on_previous_text: If True, the previous output of the model is provided
             as a prompt for the next window; disabling may make the text inconsistent across
             windows, but the model becomes less prone to getting stuck in a failure loop,
             such as repetition looping or timestamps going out of sync. Set as False
           prompt_reset_on_temperature: Resets prompt if temperature is above this value.
             Arg has effect only if condition_on_previous_text is True. Set at 0.5
           #TODO: support "hallucination_silence_threshold" when "word_timestamps=True"
           prefix: Optional text to provide as a prefix at the beginning of each window.
+          max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
+          multilingual: If True, perform transcription on multilingual videos. Set as False.
+          output_language: Valid only if multilingual is set to True.
+            Specifies the string representing the output language. One of
+            'en' (English) or 'hybrid' (code-switched transcription). set as None.
           hallucination_silence_threshold: Optional[float]
             When word_timestamps is True, skip silent periods longer than this threshold
             (in seconds) when a possible hallucination is detected. set as None.
-
-        unused:
-          language_detection_threshold: If the maximum probability of the language tokens is
-            higher than this value, the language is detected.
-          language_detection_segments: Number of segments to consider for the language detection.
-
-
         Returns:
           A tuple with:
@@ -410,7 +458,7 @@ class BatchedInferencePipeline:

             language_probability = 1

-        self.tokenizer = Tokenizer(
+        tokenizer = Tokenizer(
             self.model.hf_tokenizer,
             self.model.model.is_multilingual,
             task=task,
@@ -421,8 +469,7 @@ class BatchedInferencePipeline:
             np.stack([pad_or_trim(feature) for feature in features]) if features else []
         )

-        # batched options: see the difference with default options in WhisperModel
-        batched_options = TranscriptionOptions(
+        options = TranscriptionOptions(
             beam_size=beam_size,
             best_of=best_of,
             patience=patience,
@@ -434,12 +481,14 @@ class BatchedInferencePipeline:
             no_speech_threshold=no_speech_threshold,
             compression_ratio_threshold=compression_ratio_threshold,
             temperatures=(
-                temperature if isinstance(temperature, (list, tuple)) else [temperature]
+                temperature[:1]
+                if isinstance(temperature, (list, tuple))
+                else [temperature]
             ),
             initial_prompt=initial_prompt,
             prefix=prefix,
             suppress_blank=suppress_blank,
-            suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens),
+            suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
             prepend_punctuations=prepend_punctuations,
             append_punctuations=append_punctuations,
             max_new_tokens=max_new_tokens,
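The temperatures field now keeps only the first value when a list or tuple is passed, matching the revised docstring and the single sampling_temperature later handed to the generator. A standalone sketch of that selection expression (illustrative only):

    # Mirrors the temperature handling above: only the first value survives.
    def pick_temperatures(temperature):
        return (
            temperature[:1]
            if isinstance(temperature, (list, tuple))
            else [temperature]
        )

    print(pick_temperatures((0.0, 0.2, 0.4, 0.6, 0.8, 1.0)))  # (0.0,)
    print(pick_temperatures(0.7))                             # [0.7]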
@@ -447,7 +496,7 @@ class BatchedInferencePipeline:
             word_timestamps=word_timestamps,
             hallucination_silence_threshold=None,
             condition_on_previous_text=False,
-            clip_timestamps="0",
+            clip_timestamps=clip_timestamps,
             prompt_reset_on_temperature=0.5,
             multilingual=False,
             output_language=None,
@@ -460,31 +509,33 @@ class BatchedInferencePipeline:
             language_probability=language_probability,
             duration=duration,
             duration_after_vad=duration_after_vad,
-            transcription_options=batched_options,
-            vad_options=None,
+            transcription_options=options,
+            vad_options=vad_parameters,
             all_language_probs=all_language_probs,
         )

         segments = self._batched_segments_generator(
             features,
+            tokenizer,
             chunks_metadata,
             batch_size,
-            batched_options,
+            options,
             log_progress,
         )

         return segments, info

     def _batched_segments_generator(
-        self, features, chunks_metadata, batch_size, options, log_progress
+        self, features, tokenizer, chunks_metadata, batch_size, options, log_progress
     ):
         pbar = tqdm(total=len(features), disable=not log_progress, position=0)
         seg_idx = 0
         for i in range(0, len(features), batch_size):
             results = self.forward(
                 features[i : i + batch_size],
+                tokenizer,
                 chunks_metadata[i : i + batch_size],
-                **asdict(options),
+                options,
             )

             for result in results:
@@ -505,6 +556,7 @@ class BatchedInferencePipeline:
                     avg_logprob=segment["avg_logprob"],
                     no_speech_prob=segment["no_speech_prob"],
                     compression_ratio=segment["compression_ratio"],
+                    temperature=options.temperatures[0],
                 )

             pbar.update(1)
@@ -1689,57 +1741,6 @@ class WhisperModel:
         )
         return return_list

-    def generate_segment_batched(
-        self,
-        features: np.ndarray,
-        tokenizer: Tokenizer,
-        options: dict,
-    ):
-        batch_size = features.shape[0]
-        all_tokens = []
-        prompt_reset_since = 0
-
-        if options["initial_prompt"] is not None:
-            initial_prompt = " " + options["initial_prompt"].strip()
-            initial_prompt_tokens = tokenizer.encode(initial_prompt)
-            all_tokens.extend(initial_prompt_tokens)
-        previous_tokens = all_tokens[prompt_reset_since:]
-        prompt = self.get_prompt(
-            tokenizer,
-            previous_tokens,
-            without_timestamps=options["without_timestamps"],
-            prefix=options["prefix"],
-        )
-
-        encoder_output = self.encode(features)
-
-        result = self.model.generate(
-            encoder_output,
-            [prompt] * batch_size,
-            beam_size=options["beam_size"],
-            patience=options["patience"],
-            length_penalty=options["length_penalty"],
-            max_length=self.max_length,
-            suppress_blank=options["suppress_blank"],
-            suppress_tokens=options["suppress_tokens"],
-            return_scores=True,
-            return_no_speech_prob=True,
-        )
-
-        output = []
-        for res in result:
-            output.append({})
-            # return scores
-            seq_len = len(res.sequences_ids[0])
-            cum_logprob = res.scores[0] * (seq_len ** options["length_penalty"])
-            output[-1]["avg_logprob"] = cum_logprob / (seq_len + 1)
-
-            # return no speech prob
-            output[-1]["no_speech_prob"] = res.no_speech_prob
-            output[-1]["tokens"] = res.sequences_ids[0]
-
-        return encoder_output, output
-
     def detect_language(
         self,
         audio: Optional[np.ndarray] = None,