Mirror of https://github.com/SYSTRAN/faster-whisper.git
Add progress bar to WhisperModel.transcribe (#1138)
@@ -646,6 +646,7 @@ class WhisperModel:
         audio: Union[str, BinaryIO, np.ndarray],
         language: Optional[str] = None,
         task: str = "transcribe",
+        log_progress: bool = False,
         beam_size: int = 5,
         best_of: int = 5,
         patience: float = 1,
@@ -695,6 +696,7 @@ class WhisperModel:
             as "en" or "fr". If not set, the language will be detected in the first 30 seconds
             of audio.
           task: Task to execute (transcribe or translate).
+          log_progress: whether to show progress bar or not.
           beam_size: Beam size to use for decoding.
           best_of: Number of candidates when sampling with non-zero temperature.
           patience: Beam search patience factor.
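
For context, a minimal usage sketch of the new flag. The model size and file name are placeholders, not part of the commit:

    from faster_whisper import WhisperModel

    model = WhisperModel("small")  # hypothetical model size
    segments, info = model.transcribe("audio.wav", log_progress=True)

    # transcribe() returns a lazy generator: transcription runs, and the
    # progress bar advances, only as segments are consumed.
    for segment in segments:
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")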
@@ -941,7 +943,9 @@ class WhisperModel:
             hotwords=hotwords,
         )

-        segments = self.generate_segments(features, tokenizer, options, encoder_output)
+        segments = self.generate_segments(
+            features, tokenizer, options, log_progress, encoder_output
+        )

         if speech_chunks:
             segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
@@ -1041,6 +1045,7 @@ class WhisperModel:
         features: np.ndarray,
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
+        log_progress,
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - 1
@@ -1083,6 +1088,7 @@ class WhisperModel:
         else:
             all_tokens.extend(options.initial_prompt)

+        pbar = tqdm(total=content_duration, unit="seconds", disable=not log_progress)
         last_speech_timestamp = 0.0
         # NOTE: This loop is obscurely flattened to make the diff readable.
         # A later commit should turn this into a simpler nested loop.
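
Aside: the disable=not log_progress pattern makes the bar a no-op when the flag is off, so nothing is written to stderr in the default case. A standalone sketch with assumed values:

    from tqdm import tqdm

    content_duration = 120.0  # assumed total audio length, in seconds
    pbar = tqdm(total=content_duration, unit="seconds", disable=False)
    pbar.update(30.0)  # advance by the seconds just transcribed
    pbar.close()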
@@ -1341,6 +1347,12 @@ class WhisperModel:

             prompt_reset_since = len(all_tokens)

+            pbar.update(
+                (min(content_frames, seek) - previous_seek)
+                * self.feature_extractor.time_per_frame,
+            )
+        pbar.close()
+
     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
         # When the model is running on multiple GPUs, the encoder output should be moved
         # to the CPU since we don't know which GPU will handle the next job.
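
The update amount is a frame count converted to seconds: seek moves in mel-spectrogram frames, min() clamps the final chunk so the bar never overshoots its total, and time_per_frame scales the frame delta into the bar's seconds unit. A toy check, assuming Whisper's usual feature settings (hop length 160 at 16 kHz, i.e. 0.01 s per frame; these numbers are not stated in the diff):

    hop_length, sampling_rate = 160, 16000        # assumed defaults
    time_per_frame = hop_length / sampling_rate   # 0.01 s per mel frame

    content_frames = 4500                 # e.g. a 45-second file
    previous_seek, seek = 3000, 6000      # seek can overshoot the end of the file
    advance = (min(content_frames, seek) - previous_seek) * time_per_frame
    print(advance)  # 15.0 -> the bar advances by the 15 seconds just processed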