mirror of
https://github.com/SYSTRAN/faster-whisper.git
synced 2026-01-09 21:48:08 -05:00
Automatically download converted models from the Hugging Face Hub (#70)
* Automatically download converted models from the Hugging Face Hub
* Remove unused import
* Remove non needed requirements in dev mode
* Remove extra index URL when pip install in CI
* Allow downloading to a specific directory
* Update docstring
* Add argument to disable the progress bars
* Fix typo in docstring
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from faster_whisper.audio import decode_audio
|
||||
from faster_whisper.transcribe import WhisperModel
|
||||
from faster_whisper.utils import format_timestamp
|
||||
from faster_whisper.utils import download_model, format_timestamp
|
||||
|
||||
__all__ = [
|
||||
"decode_audio",
|
||||
"WhisperModel",
|
||||
"download_model",
|
||||
"format_timestamp",
|
||||
]
|
||||
|
||||
@@ -11,6 +11,7 @@ import tokenizers
|
||||
from faster_whisper.audio import decode_audio
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
from faster_whisper.tokenizer import Tokenizer
|
||||
from faster_whisper.utils import download_model
|
||||
|
||||
|
||||
class Word(NamedTuple):
|
||||
@@ -57,7 +58,7 @@ class TranscriptionOptions(NamedTuple):
|
||||
class WhisperModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
model_size_or_path: str,
|
||||
device: str = "auto",
|
||||
device_index: Union[int, List[int]] = 0,
|
||||
compute_type: str = "default",
|
||||
@@ -67,7 +68,9 @@ class WhisperModel:
|
||||
"""Initializes the Whisper model.
|
||||
|
||||
Args:
|
||||
model_path: Path to the converted model.
|
||||
model_size_or_path: Size of the model to use (e.g. "large-v2", "small", "tiny.en", etc.)
|
||||
or a path to a converted model directory. When a size is configured, the converted
|
||||
model is downloaded from the Hugging Face Hub.
|
||||
device: Device to use for computation ("cpu", "cuda", "auto").
|
||||
device_index: Device ID to use.
|
||||
The model can also be loaded on multiple GPUs by passing a list of IDs
|
||||
@@ -82,6 +85,11 @@ class WhisperModel:
|
||||
(concurrent calls to self.model.generate() will run in parallel).
|
||||
This can improve the global throughput at the cost of increased memory usage.
|
||||
"""
|
||||
if os.path.isdir(model_size_or_path):
|
||||
model_path = model_size_or_path
|
||||
else:
|
||||
model_path = download_model(model_size_or_path)
|
||||
|
||||
self.model = ctranslate2.models.Whisper(
|
||||
model_path,
|
||||
device=device,
|
||||
|
||||
@@ -1,3 +1,42 @@
|
||||
from typing import Optional
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def download_model(
    size: str,
    output_dir: Optional[str] = None,
    show_progress_bars: bool = True,
):
    """Fetches a converted CTranslate2 Whisper model from the Hugging Face Hub.

    The repository is looked up under https://huggingface.co/guillaumekln, using
    the name "faster-whisper-<size>".

    Args:
      size: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en,
        medium, medium.en, or large-v2).
      output_dir: Directory where the model should be saved. When left unset, no target
        directory is passed and the Hugging Face Hub default location is used.
      show_progress_bars: Set to False to replace the tqdm progress bars with a
        disabled variant during the download.

    Returns:
      The path to the downloaded model, as returned by
      huggingface_hub.snapshot_download.
    """
    repository = "guillaumekln/faster-whisper-%s" % size

    # Only forward the options the caller actually asked for, so that
    # snapshot_download keeps its own defaults for everything else.
    download_options = {}

    if output_dir is not None:
        # Write real files into the target directory rather than symlinks.
        download_options.update(
            local_dir=output_dir,
            local_dir_use_symlinks=False,
        )

    if not show_progress_bars:
        download_options["tqdm_class"] = disabled_tqdm

    return huggingface_hub.snapshot_download(repository, **download_options)
|
||||
|
||||
|
||||
def format_timestamp(
|
||||
seconds: float,
|
||||
always_include_hours: bool = False,
|
||||
@@ -19,3 +58,9 @@ def format_timestamp(
|
||||
return (
|
||||
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
)
|
||||
|
||||
|
||||
class disabled_tqdm(tqdm):
    """tqdm subclass whose instances are always constructed with disable=True.

    download_model passes this class as ``tqdm_class`` when progress bars are
    turned off.
    """

    def __init__(self, *args, **kwargs):
        # Override any caller-supplied "disable" value before delegating to tqdm.
        super().__init__(*args, **{**kwargs, "disable": True})
|
||||
|
||||
Reference in New Issue
Block a user