feat: Allow loading of private HF models (#1309)

* feat: add HuggingFace auth token support to model download

* Format
This commit is contained in:
Rishil
2025-06-02 12:12:34 +01:00
committed by GitHub
parent 43d4163fe0
commit d3bfd0a305
2 changed files with 11 additions and 1 deletions

View File

@@ -597,6 +597,7 @@ class WhisperModel:
local_files_only: bool = False,
files: dict = None,
revision: Optional[str] = None,
use_auth_token: Optional[Union[str, bool]] = None,
**model_kwargs,
):
"""Initializes the Whisper model.
@@ -631,6 +632,8 @@ class WhisperModel:
revision:
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
use_auth_token: HuggingFace authentication token or True to use the
token stored by the HuggingFace config folder.
"""
self.logger = get_logger()
@@ -647,6 +650,7 @@ class WhisperModel:
local_files_only=local_files_only,
cache_dir=download_root,
revision=revision,
use_auth_token=use_auth_token,
)
self.model = ctranslate2.models.Whisper(

View File

@@ -2,7 +2,7 @@ import logging
import os
import re
from typing import List, Optional
from typing import List, Optional, Union
import huggingface_hub
import requests
@@ -53,6 +53,7 @@ def download_model(
local_files_only: bool = False,
cache_dir: Optional[str] = None,
revision: Optional[str] = None,
use_auth_token: Optional[Union[str, bool]] = None,
):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
@@ -69,6 +70,8 @@ def download_model(
cache_dir: Path to the folder where cached files are stored.
revision: An optional Git revision id which can be a branch name, a tag, or a
commit hash.
use_auth_token: HuggingFace authentication token or True to use the
token stored by the HuggingFace config folder.
Returns:
The path to the downloaded model.
@@ -108,6 +111,9 @@ def download_model(
if cache_dir is not None:
kwargs["cache_dir"] = cache_dir
if use_auth_token is not None:
kwargs["token"] = use_auth_token
try:
return huggingface_hub.snapshot_download(repo_id, **kwargs)
except (