3 Commits
310 ... v1.2.1

Author SHA1 Message Date
Mahmoud Ashraf
65882eee9f Bump version to 1.2.1 2025-10-31 14:31:14 +03:00
Mahmoud Ashraf
409a6919f9 Prevent timestamps restoration when clip timestamps are provided in batched inference (#1376) 2025-10-31 14:26:17 +03:00
Mahmoud Ashraf
00a5b26b1f Offload retry logic to hf hub (#1382)
* remove requirement for requests
2025-10-30 22:11:01 +03:00
5 changed files with 43 additions and 24 deletions

View File

@@ -418,23 +418,34 @@ class BatchedInferencePipeline:
"Set 'vad_filter' to True or provide 'clip_timestamps'."
)
clip_timestamps_provided = False
audio_chunks, chunks_metadata = collect_chunks(
audio, clip_timestamps, max_duration=chunk_length
)
else:
clip_timestamps_provided = True
clip_timestamps = [
{k: int(v * sampling_rate) for k, v in segment.items()}
for segment in clip_timestamps
]
audio_chunks, chunks_metadata = [], []
for clip in clip_timestamps:
for i, clip in enumerate(clip_timestamps):
audio_chunks.append(audio[clip["start"] : clip["end"]])
clip_duration = (clip["end"] - clip["start"]) / sampling_rate
if clip_duration > 30:
self.model.logger.warning(
"Segment %d is longer than 30 seconds, "
"only the first 30 seconds will be transcribed",
i,
)
chunks_metadata.append(
{
"offset": clip["start"] / sampling_rate,
"duration": (clip["end"] - clip["start"]) / sampling_rate,
"duration": clip_duration,
"segments": [clip],
}
)
@@ -559,7 +570,10 @@ class BatchedInferencePipeline:
options,
log_progress,
)
segments = restore_speech_timestamps(segments, clip_timestamps, sampling_rate)
if not clip_timestamps_provided:
segments = restore_speech_timestamps(
segments, clip_timestamps, sampling_rate
)
return segments, info

View File

@@ -5,7 +5,6 @@ import re
from typing import List, Optional, Union
import huggingface_hub
import requests
from tqdm.auto import tqdm
@@ -114,24 +113,7 @@ def download_model(
if use_auth_token is not None:
kwargs["token"] = use_auth_token
try:
return huggingface_hub.snapshot_download(repo_id, **kwargs)
except (
huggingface_hub.utils.HfHubHTTPError,
requests.exceptions.ConnectionError,
) as exception:
logger = get_logger()
logger.warning(
"An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
repo_id,
exception,
)
logger.warning(
"Trying to load the model directly from the local cache, if it exists."
)
kwargs["local_files_only"] = True
return huggingface_hub.snapshot_download(repo_id, **kwargs)
return huggingface_hub.snapshot_download(repo_id, **kwargs)
def format_timestamp(

View File

@@ -1,3 +1,3 @@
"""Version information."""
__version__ = "1.2.0"
__version__ = "1.2.1"

View File

@@ -1,5 +1,5 @@
ctranslate2>=4.0,<5
huggingface_hub>=0.13
huggingface_hub>=0.21
tokenizers>=0.13,<1
onnxruntime>=1.14,<2
av>=11

View File

@@ -290,3 +290,26 @@ def test_cliptimestamps_segments(jfk_path):
" And so my fellow Americans ask not what your country can do for you, "
"ask what you can do for your country."
)
def test_cliptimestamps_timings(physcisworks_path):
    """When the caller supplies clip_timestamps, the returned segments must keep
    exactly those start/end times (no restoration pass) and the expected text."""
    pipeline = BatchedInferencePipeline(model=WhisperModel("tiny"))
    waveform = decode_audio(physcisworks_path)

    clips = [{"start": 0.0, "end": 5.0}, {"start": 6.0, "end": 15.0}]
    expected_texts = [
        " Now I want to return to the conservation of mechanical energy.",
        (
            " I have here a pendulum. I have an object that weighs 15 kilograms"
            " and I can lift it up one meter, which I have done now."
        ),
    ]

    result, info = pipeline.transcribe(waveform, clip_timestamps=clips)
    result = list(result)

    # One output segment per provided clip.
    assert len(result) == 2
    for seg, clip, text in zip(result, clips, expected_texts):
        # Timestamps are taken verbatim from the provided clips.
        assert seg.start == clip["start"]
        assert seg.end == clip["end"]
        assert seg.text == text