mirror of
https://github.com/SYSTRAN/faster-whisper.git
synced 2026-01-09 21:48:08 -05:00
feat: allow passing specific revision to download (#1292)
This commit is contained in:
@@ -596,6 +596,7 @@ class WhisperModel:
|
|||||||
download_root: Optional[str] = None,
|
download_root: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
files: dict = None,
|
files: dict = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
):
|
):
|
||||||
"""Initializes the Whisper model.
|
"""Initializes the Whisper model.
|
||||||
@@ -627,6 +628,9 @@ class WhisperModel:
|
|||||||
files: Load model files from the memory. This argument is a dictionary mapping file names
|
files: Load model files from the memory. This argument is a dictionary mapping file names
|
||||||
to file contents as file-like or bytes objects. If this is set, model_path acts as an
|
to file contents as file-like or bytes objects. If this is set, model_path acts as an
|
||||||
identifier for this model.
|
identifier for this model.
|
||||||
|
revision:
|
||||||
|
An optional Git revision id which can be a branch name, a tag, or a
|
||||||
|
commit hash.
|
||||||
"""
|
"""
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
|
|
||||||
@@ -642,6 +646,7 @@ class WhisperModel:
|
|||||||
model_size_or_path,
|
model_size_or_path,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
cache_dir=download_root,
|
cache_dir=download_root,
|
||||||
|
revision=revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = ctranslate2.models.Whisper(
|
self.model = ctranslate2.models.Whisper(
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ def download_model(
|
|||||||
output_dir: Optional[str] = None,
|
output_dir: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
|
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
|
||||||
|
|
||||||
@@ -65,6 +66,8 @@ def download_model(
|
|||||||
local_files_only: If True, avoid downloading the file and return the path to the local
|
local_files_only: If True, avoid downloading the file and return the path to the local
|
||||||
cached file if it exists.
|
cached file if it exists.
|
||||||
cache_dir: Path to the folder where cached files are stored.
|
cache_dir: Path to the folder where cached files are stored.
|
||||||
|
revision: An optional Git revision id which can be a branch name, a tag, or a
|
||||||
|
commit hash.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The path to the downloaded model.
|
The path to the downloaded model.
|
||||||
@@ -94,6 +97,7 @@ def download_model(
|
|||||||
"local_files_only": local_files_only,
|
"local_files_only": local_files_only,
|
||||||
"allow_patterns": allow_patterns,
|
"allow_patterns": allow_patterns,
|
||||||
"tqdm_class": disabled_tqdm,
|
"tqdm_class": disabled_tqdm,
|
||||||
|
"revision": revision,
|
||||||
}
|
}
|
||||||
|
|
||||||
if output_dir is not None:
|
if output_dir is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user