Cleanup of BatchedInferencePipeline (#1135)

Mahmoud Ashraf
2024-11-17 15:45:32 +02:00
committed by GitHub
parent a6f8fbae00
commit be9fb36ed3

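For orientation, a minimal usage sketch of the batched pipeline after this cleanup (model name and audio path are placeholders):

from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("large-v3", device="cuda", compute_type="float16")
pipeline = BatchedInferencePipeline(model=model)

# transcribe() returns a lazy generator of segments plus a TranscriptionInfo
segments, info = pipeline.transcribe("audio.mp3", batch_size=8)
for segment in segments:
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")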

@@ -57,7 +57,7 @@ class Segment:
compression_ratio: float
no_speech_prob: float
words: Optional[List[Word]]
temperature: Optional[float] = 1.0
temperature: Optional[float]
def _asdict(self):
warn(
@@ -68,7 +68,6 @@ class Segment:
return asdict(self)
# Added additional parameters for multilingual videos and fixes below
@dataclass
class TranscriptionOptions:
beam_size: int
@@ -112,34 +111,17 @@ class TranscriptionInfo:
vad_options: VadOptions
# The code below is originally from HF pipeline and is used in whisper-x
# (https://github.com/m-bain/whisperX) and adapted for faster_whisper
class BatchedInferencePipeline:
"""
Huggingface Pipeline wrapper for WhisperModel.
Copyright (c) 2022, Max Bain
All rights reserved.
Modified by Mobius Labs GmbH
"""
def __init__(
self,
model,
options: Optional[TranscriptionOptions] = None,
tokenizer=None,
language: Optional[str] = None,
):
self.model: WhisperModel = model
self.tokenizer = tokenizer
self.options = options
self.preset_language = language
self.last_speech_timestamp = 0.0
def forward(self, features, chunks_metadata, **forward_params):
encoder_output, outputs = self.model.generate_segment_batched(
features, self.tokenizer, forward_params
def forward(self, features, tokenizer, chunks_metadata, options):
encoder_output, outputs = self.generate_segment_batched(
features, tokenizer, options
)
segmented_outputs = []
@@ -153,7 +135,7 @@ class BatchedInferencePipeline:
seek,
single_timestamp_ending,
) = self.model._split_segments_by_timestamps(
tokenizer=self.tokenizer,
tokenizer=tokenizer,
tokens=output["tokens"],
time_offset=chunk_metadata["start_time"],
segment_size=segment_size,
@@ -163,14 +145,14 @@ class BatchedInferencePipeline:
segmented_outputs.append(
[
dict(
text=self.tokenizer.decode(subsegment["tokens"]),
text=tokenizer.decode(subsegment["tokens"]),
avg_logprob=output["avg_logprob"],
no_speech_prob=output["no_speech_prob"],
tokens=subsegment["tokens"],
start=subsegment["start"],
end=subsegment["end"],
compression_ratio=get_compression_ratio(
self.tokenizer.decode(subsegment["tokens"])
tokenizer.decode(subsegment["tokens"])
),
seek=int(
chunk_metadata["start_time"] * self.model.frames_per_second
@@ -179,19 +161,88 @@ class BatchedInferencePipeline:
for subsegment in subsegments
]
)
if forward_params["word_timestamps"]:
if options.word_timestamps:
self.last_speech_timestamp = self.model.add_word_timestamps(
segmented_outputs,
self.tokenizer,
tokenizer,
encoder_output,
segment_sizes,
forward_params["prepend_punctuations"],
forward_params["append_punctuations"],
options.prepend_punctuations,
options.append_punctuations,
self.last_speech_timestamp,
)
return segmented_outputs
def generate_segment_batched(
self,
features: np.ndarray,
tokenizer: Tokenizer,
options: TranscriptionOptions,
):
batch_size = features.shape[0]
prompt = self.model.get_prompt(
tokenizer,
previous_tokens=(
tokenizer.encode(options.initial_prompt)
if options.initial_prompt is not None
else []
),
without_timestamps=options.without_timestamps,
hotwords=options.hotwords,
)
if options.max_new_tokens is not None:
max_length = len(prompt) + options.max_new_tokens
else:
max_length = self.model.max_length
if max_length > self.model.max_length:
raise ValueError(
f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
f"and `max_new_tokens` is: {max_length}. This exceeds the "
f"`max_length` of the Whisper model: {self.model.max_length}. "
"You should either reduce the length of your prompt, or "
"reduce the value of `max_new_tokens`, "
f"so that their combined length is less that {self.model.max_length}."
)
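# e.g. with Whisper's max_length of 448 and a 32-token prompt, the largest
# usable max_new_tokens is 448 - 32 = 416 before this raises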
encoder_output = self.model.encode(features)
results = self.model.model.generate(
encoder_output,
[prompt] * batch_size,
beam_size=options.beam_size,
patience=options.patience,
length_penalty=options.length_penalty,
max_length=max_length,
suppress_blank=options.suppress_blank,
suppress_tokens=options.suppress_tokens,
return_scores=True,
return_no_speech_prob=True,
sampling_temperature=options.temperatures[0],
repetition_penalty=options.repetition_penalty,
no_repeat_ngram_size=options.no_repeat_ngram_size,
)
output = []
for result in results:
# generation scores come back length-normalized, so scale by
# seq_len ** length_penalty to recover the cumulative log-probability
seq_len = len(result.sequences_ids[0])
cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
output.append(
dict(
avg_logprob=cum_logprob / (seq_len + 1),
no_speech_prob=result.no_speech_prob,
tokens=result.sequences_ids[0],
)
)
return encoder_output, output
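Illustratively, the (encoder_output, outputs) pair returned above carries one dict per chunk in the batch (the values below are made up; token ids truncated):

encoder_output, outputs = pipeline.generate_segment_batched(features, tokenizer, options)
# [{"avg_logprob": -0.21, "no_speech_prob": 0.03, "tokens": [50365, 2425, ...]},
#  {"avg_logprob": -1.37, "no_speech_prob": 0.72, "tokens": [50365, ...]}]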
def transcribe(
self,
audio: Union[str, BinaryIO, np.ndarray],
@@ -216,20 +267,26 @@ class BatchedInferencePipeline:
log_prob_threshold: Optional[float] = -1.0,
log_prob_low_threshold: Optional[float] = None,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
prompt_reset_on_temperature: float = 0.5,
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
prefix: Optional[str] = None,
suppress_blank: bool = True,
suppress_tokens: Optional[List[int]] = [-1],
without_timestamps: bool = True,
max_initial_timestamp: float = 1.0,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
multilingual: bool = False,
output_language: Optional[str] = None,
vad_filter: bool = True,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
chunk_length: Optional[int] = None,
clip_timestamps: Optional[List[dict]] = None,
batch_size: int = 16,
hallucination_silence_threshold: Optional[float] = None,
batch_size: int = 8,
hotwords: Optional[str] = None,
language_detection_threshold: Optional[float] = 0.5,
language_detection_segments: int = 1,
@@ -250,22 +307,10 @@ class BatchedInferencePipeline:
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`.
compression_ratio_threshold: If the gzip compression ratio is above this value,
treat as failed.
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: If the average log probability over sampled tokens is
below this value, the output text is skipped on this check alone, whereas
log_prob_threshold is applied only in combination with no_speech_threshold.
This value should be less than log_prob_threshold.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
temperature: Temperature for sampling. If a list or tuple is passed,
only the first value is used (see the sketch after this argument list).
initial_prompt: Optional text string or iterable of token ids to provide as a
prompt for the first window.
prefix: Optional text to provide as a prefix for the first window.
prompt for each window.
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in `tokenizer.non_speech_tokens()`.
@@ -296,29 +341,32 @@ class BatchedInferencePipeline:
higher than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
Static params: (Fixed for batched version)
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
multilingual: If True, perform transcription on multilingual videos. Set as False.
output_language: Valid only if multilingual is set to True.
Specifies the string representing the output language. One of
'en' (English) or 'hybrid' (code-switched transcription). Set as None.
Unused Arguments
compression_ratio_threshold: If the gzip compression ratio is above this value,
treat as failed.
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: If the average log probability over sampled tokens is
below this value, the output text is skipped on this check alone, whereas
log_prob_threshold is applied only in combination with no_speech_threshold.
This value should be less than log_prob_threshold.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
condition_on_previous_text: If True, the previous output of the model is provided
as a prompt for the next window; disabling may make the text inconsistent across
windows, but the model becomes less prone to getting stuck in a failure loop,
such as repetition looping or timestamps going out of sync. Set as False.
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
Arg has effect only if condition_on_previous_text is True. Set at 0.5.
# TODO: support "hallucination_silence_threshold" when "word_timestamps=True"
prefix: Optional text to provide as a prefix at the beginning of each window.
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
multilingual: If True, perform transcription on multilingual videos. Set as False.
output_language: Valid only if multilingual is set to True.
Specifies the string representing the output language. One of
'en' (English) or 'hybrid' (code-switched transcription). Set as None.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected. Set as None.
Unused:
language_detection_threshold: If the maximum probability of the language tokens is
higher than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
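A sketch of the temperature behavior noted above, assuming the pipeline from the earlier example:

segments, info = pipeline.transcribe("audio.mp3", temperature=(0.0, 0.2, 0.4))
# batched inference has no retry loop, so only the first value survives:
print(info.transcription_options.temperatures)  # (0.0,)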
Returns:
A tuple with:
@@ -410,7 +458,7 @@ class BatchedInferencePipeline:
language_probability = 1
self.tokenizer = Tokenizer(
tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
@@ -421,8 +469,7 @@ class BatchedInferencePipeline:
np.stack([pad_or_trim(feature) for feature in features]) if features else []
)
# batched options: see the difference with default options in WhisperModel
batched_options = TranscriptionOptions(
options = TranscriptionOptions(
beam_size=beam_size,
best_of=best_of,
patience=patience,
@@ -434,12 +481,14 @@ class BatchedInferencePipeline:
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
temperatures=(
temperature if isinstance(temperature, (list, tuple)) else [temperature]
temperature[:1]
if isinstance(temperature, (list, tuple))
else [temperature]
),
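# note: batched mode never falls back to higher temperatures, hence the [:1]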
initial_prompt=initial_prompt,
prefix=prefix,
suppress_blank=suppress_blank,
suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens),
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
max_new_tokens=max_new_tokens,
@@ -447,7 +496,7 @@ class BatchedInferencePipeline:
word_timestamps=word_timestamps,
hallucination_silence_threshold=None,
condition_on_previous_text=False,
clip_timestamps="0",
clip_timestamps=clip_timestamps,
prompt_reset_on_temperature=0.5,
multilingual=False,
output_language=None,
@@ -460,31 +509,33 @@ class BatchedInferencePipeline:
language_probability=language_probability,
duration=duration,
duration_after_vad=duration_after_vad,
transcription_options=batched_options,
vad_options=None,
transcription_options=options,
vad_options=vad_parameters,
all_language_probs=all_language_probs,
)
segments = self._batched_segments_generator(
features,
tokenizer,
chunks_metadata,
batch_size,
batched_options,
options,
log_progress,
)
return segments, info
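As with WhisperModel.transcribe, nothing is decoded until the returned generator is consumed; a sketch:

segments, info = pipeline.transcribe("audio.mp3", batch_size=8, log_progress=True)
segments = list(segments)  # transcription actually runs here, not at the call above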
def _batched_segments_generator(
self, features, chunks_metadata, batch_size, options, log_progress
self, features, tokenizer, chunks_metadata, batch_size, options, log_progress
):
pbar = tqdm(total=len(features), disable=not log_progress, position=0)
seg_idx = 0
for i in range(0, len(features), batch_size):
results = self.forward(
features[i : i + batch_size],
tokenizer,
chunks_metadata[i : i + batch_size],
**asdict(options),
options,
)
for result in results:
@@ -505,6 +556,7 @@ class BatchedInferencePipeline:
avg_logprob=segment["avg_logprob"],
no_speech_prob=segment["no_speech_prob"],
compression_ratio=segment["compression_ratio"],
temperature=options.temperatures[0],
)
pbar.update(1)
@@ -1689,57 +1741,6 @@ class WhisperModel:
)
return return_list
def generate_segment_batched(
self,
features: np.ndarray,
tokenizer: Tokenizer,
options: dict,
):
batch_size = features.shape[0]
all_tokens = []
prompt_reset_since = 0
if options["initial_prompt"] is not None:
initial_prompt = " " + options["initial_prompt"].strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
tokenizer,
previous_tokens,
without_timestamps=options["without_timestamps"],
prefix=options["prefix"],
)
encoder_output = self.encode(features)
result = self.model.generate(
encoder_output,
[prompt] * batch_size,
beam_size=options["beam_size"],
patience=options["patience"],
length_penalty=options["length_penalty"],
max_length=self.max_length,
suppress_blank=options["suppress_blank"],
suppress_tokens=options["suppress_tokens"],
return_scores=True,
return_no_speech_prob=True,
)
output = []
for res in result:
output.append({})
# return scores
seq_len = len(res.sequences_ids[0])
cum_logprob = res.scores[0] * (seq_len ** options["length_penalty"])
output[-1]["avg_logprob"] = cum_logprob / (seq_len + 1)
# return no speech prob
output[-1]["no_speech_prob"] = res.no_speech_prob
output[-1]["tokens"] = res.sequences_ids[0]
return encoder_output, output
def detect_language(
self,
audio: Optional[np.ndarray] = None,