mirror of
https://github.com/SYSTRAN/faster-whisper.git
synced 2026-01-08 13:14:00 -05:00
feat: Allow loading of private HF models (#1309)
* feat: add HuggingFace auth token support to model download * Format
This commit is contained in:
@@ -597,6 +597,7 @@ class WhisperModel:
|
||||
local_files_only: bool = False,
|
||||
files: dict = None,
|
||||
revision: Optional[str] = None,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
**model_kwargs,
|
||||
):
|
||||
"""Initializes the Whisper model.
|
||||
@@ -631,6 +632,8 @@ class WhisperModel:
|
||||
revision:
|
||||
An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash.
|
||||
use_auth_token: HuggingFace authentication token or True to use the
|
||||
token stored by the HuggingFace config folder.
|
||||
"""
|
||||
self.logger = get_logger()
|
||||
|
||||
@@ -647,6 +650,7 @@ class WhisperModel:
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=download_root,
|
||||
revision=revision,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
self.model = ctranslate2.models.Whisper(
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import huggingface_hub
|
||||
import requests
|
||||
@@ -53,6 +53,7 @@ def download_model(
|
||||
local_files_only: bool = False,
|
||||
cache_dir: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
):
|
||||
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
|
||||
|
||||
@@ -69,6 +70,8 @@ def download_model(
|
||||
cache_dir: Path to the folder where cached files are stored.
|
||||
revision: An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash.
|
||||
use_auth_token: HuggingFace authentication token or True to use the
|
||||
token stored by the HuggingFace config folder.
|
||||
|
||||
Returns:
|
||||
The path to the downloaded model.
|
||||
@@ -108,6 +111,9 @@ def download_model(
|
||||
if cache_dir is not None:
|
||||
kwargs["cache_dir"] = cache_dir
|
||||
|
||||
if use_auth_token is not None:
|
||||
kwargs["token"] = use_auth_token
|
||||
|
||||
try:
|
||||
return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
||||
except (
|
||||
|
||||
Reference in New Issue
Block a user