feat: allow passing specific revision to download (#1292)

This commit is contained in:
Felix Mosheev
2025-04-30 00:55:48 +03:00
committed by GitHub
parent 1383fd4d37
commit 700584b2e6
2 changed files with 9 additions and 0 deletions

View File

@@ -596,6 +596,7 @@ class WhisperModel:
download_root: Optional[str] = None,
local_files_only: bool = False,
files: dict = None,
revision: Optional[str] = None,
**model_kwargs,
):
"""Initializes the Whisper model.
@@ -627,6 +628,9 @@ class WhisperModel:
files: Load model files from the memory. This argument is a dictionary mapping file names
to file contents as file-like or bytes objects. If this is set, model_path acts as an
identifier for this model.
revision:
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
"""
self.logger = get_logger()
@@ -642,6 +646,7 @@ class WhisperModel:
model_size_or_path,
local_files_only=local_files_only,
cache_dir=download_root,
revision=revision,
)
self.model = ctranslate2.models.Whisper(

View File

@@ -51,6 +51,7 @@ def download_model(
output_dir: Optional[str] = None,
local_files_only: bool = False,
cache_dir: Optional[str] = None,
revision: Optional[str] = None,
):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
@@ -65,6 +66,8 @@ def download_model(
local_files_only: If True, avoid downloading the file and return the path to the local
cached file if it exists.
cache_dir: Path to the folder where cached files are stored.
revision: An optional Git revision id which can be a branch name, a tag, or a
commit hash.
Returns:
The path to the downloaded model.
@@ -94,6 +97,7 @@ def download_model(
"local_files_only": local_files_only,
"allow_patterns": allow_patterns,
"tqdm_class": disabled_tqdm,
"revision": revision,
}
if output_dir is not None: