Mirror of https://github.com/SYSTRAN/faster-whisper.git, synced 2026-01-12 23:18:06 -05:00.

Comparing 8 commits:

* 7808eddf06
* de7682a2f0
* 523ae2180f
* 2b7be47041
* 3f02c53610
* e663186a4b
* e44a8c7ba0
* 33f41d84e3
.github/workflows/ci.yml (vendored): 34 changed lines
@@ -25,7 +25,7 @@ jobs:
       - name: Install module
         run: |
           pip install wheel
-          pip install .[dev] --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e .[dev]
 
       - name: Check code format with Black
         run: |
@@ -55,8 +55,36 @@ jobs:
       - name: Install module
         run: |
           pip install wheel
-          pip install .[dev] --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e .[dev]
 
       - name: Run pytest
         run: |
-          pytest -v tests/test.py
+          pytest -v tests/
+
+
+  build-and-push-package:
+    runs-on: ubuntu-latest
+    needs: [check-code-format, run-tests]
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Set up Python 3.8
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.8
+
+      - name: Install dependencies
+        run: |
+          pip install wheel
+
+      - name: Build package
+        run: |
+          python3 setup.py sdist bdist_wheel
+
+      - name: Push package on PyPI
+        if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          user: __token__
+          password: ${{ secrets.PYPI_API_TOKEN }}
.gitignore (vendored, new file): 1 line

@@ -0,0 +1 @@
+*.pyc
README.md: 67 changed lines

@@ -1,12 +1,14 @@
 [![CI](https://github.com/guillaumekln/faster-whisper/workflows/CI/badge.svg)](https://github.com/guillaumekln/faster-whisper/actions?query=workflow%3ACI) [![PyPI version](https://badge.fury.io/py/faster-whisper.svg)](https://badge.fury.io/py/faster-whisper)
 
 # Faster Whisper transcription with CTranslate2
 
-This repository demonstrates how to implement the Whisper transcription using [CTranslate2](https://github.com/OpenNMT/CTranslate2/), which is a fast inference engine for Transformer models.
+**faster-whisper** is a reimplementation of OpenAI's Whisper model using [CTranslate2](https://github.com/OpenNMT/CTranslate2/), which is a fast inference engine for Transformer models.
 
 This implementation is up to 4 times faster than [openai/whisper](https://github.com/openai/whisper) for the same accuracy while using less memory. The efficiency can be further improved with 8-bit quantization on both CPU and GPU.
 
 ## Benchmark
 
-For reference, here's the time and memory usage that are required to transcribe **13 minutes** of audio using different implementations:
+For reference, here's the time and memory usage that are required to transcribe [**13 minutes**](https://www.youtube.com/watch?v=0u7tTptBo9I) of audio using different implementations:
 
 * [openai/whisper](https://github.com/openai/whisper)@[6dea21fd](https://github.com/openai/whisper/commit/6dea21fd7f7253bfe450f1e2512a0fe47ee2d258)
 * [whisper.cpp](https://github.com/ggerganov/whisper.cpp)@[3b010f9](https://github.com/ggerganov/whisper.cpp/commit/3b010f9bed9a6068609e9faf52383aea792b0362)
@@ -36,24 +38,24 @@ For reference, here's the time and memory usage that are required to transcribe
 ## Installation
 
-```bash
-pip install -e .[conversion]
-```
-
-The model conversion requires the modules `transformers` and `torch` which are installed by the `[conversion]` requirement. Once a model is converted, these modules are no longer needed and the installation could be simplified to:
+The module can be installed from [PyPI](https://pypi.org/project/faster-whisper/):
 
 ```bash
-pip install -e .
+pip install faster-whisper
 ```
 
-It is also possible to install the module without cloning the Git repository:
+**Other installation methods:**
 
 ```bash
 # Install the master branch:
-pip install "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/refs/heads/master.tar.gz"
+pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/refs/heads/master.tar.gz"
 
 # Install a specific commit:
-pip install "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/a4f1cc8f11433e454c3934442b5e1a4ed5e865c3.tar.gz"
+pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/a4f1cc8f11433e454c3934442b5e1a4ed5e865c3.tar.gz"
+
+# Install for development:
+git clone https://github.com/guillaumekln/faster-whisper.git
+pip install -e faster-whisper/
 ```
 
 ### GPU support
 
@@ -62,35 +64,20 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst
 ## Usage
 
-### Model conversion
-
-A Whisper model should be first converted into the CTranslate2 format. We provide a script to download and convert models from the [Hugging Face model repository](https://huggingface.co/models?sort=downloads&search=whisper).
-
-For example, the command below converts the "large-v2" Whisper model and saves the weights in FP16:
-
-```bash
-ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper-large-v2-ct2 \
-    --copy_files tokenizer.json --quantization float16
-```
-
-If the option `--copy_files tokenizer.json` is not used, the tokenizer configuration is automatically downloaded when the model is loaded later.
-
-Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
-
 ### Transcription
 
 ```python
 from faster_whisper import WhisperModel
 
-model_path = "whisper-large-v2-ct2/"
+model_size = "large-v2"
 
 # Run on GPU with FP16
-model = WhisperModel(model_path, device="cuda", compute_type="float16")
+model = WhisperModel(model_size, device="cuda", compute_type="float16")
 
 # or run on GPU with INT8
-# model = WhisperModel(model_path, device="cuda", compute_type="int8_float16")
+# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
 # or run on CPU with INT8
-# model = WhisperModel(model_path, device="cpu", compute_type="int8")
+# model = WhisperModel(model_size, device="cpu", compute_type="int8")
 
 segments, info = model.transcribe("audio.mp3", beam_size=5)
 
@@ -112,6 +99,26 @@ for segment in segments:
 
 See more model and transcription options in the [`WhisperModel`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation.
 
+## Model conversion
+
+When loading a model from its size such as `WhisperModel("large-v2")`, the corresponding CTranslate2 model is automatically downloaded from the [Hugging Face Hub](https://huggingface.co/guillaumekln).
+
+We also provide a script to convert any Whisper models compatible with the Transformers library. They could be the original OpenAI models or user fine-tuned models.
+
+For example, the command below converts the [original "large-v2" Whisper model](https://huggingface.co/openai/whisper-large-v2) and saves the weights in FP16:
+
+```bash
+pip install transformers[torch]>=4.23
+
+ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper-large-v2-ct2 \
+    --copy_files tokenizer.json --quantization float16
+```
+
+* The option `--model` accepts a model name on the Hub or a path to a model directory.
+* If the option `--copy_files tokenizer.json` is not used, the tokenizer configuration is automatically downloaded when the model is loaded later.
+
+Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
 
 ## Comparing performance against other implementations
 
 If you are comparing the performance against other Whisper implementations, you should make sure to run the comparison with similar settings. In particular:
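As a companion to the README's pointer to the conversion API, here is a short Python sketch of converting a model from code; it mirrors the `TransformersConverter` call that appears in this repository's test helpers (the model name and output directory are example values):

```python
import ctranslate2

# Download and convert the Transformers "tiny" Whisper model, copying the
# tokenizer configuration and saving the weights in FP16.
converter = ctranslate2.converters.TransformersConverter(
    "openai/whisper-tiny",
    copy_files=["tokenizer.json"],
    load_as_float16=True,
)
converter.convert("whisper-tiny-ct2", quantization="float16")
```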
faster_whisper/__init__.py

@@ -1,9 +1,10 @@
 from faster_whisper.audio import decode_audio
 from faster_whisper.transcribe import WhisperModel
-from faster_whisper.utils import format_timestamp
+from faster_whisper.utils import download_model, format_timestamp
 
 __all__ = [
     "decode_audio",
     "WhisperModel",
+    "download_model",
     "format_timestamp",
 ]
faster_whisper/transcribe.py

@@ -11,6 +11,7 @@ import tokenizers
 from faster_whisper.audio import decode_audio
 from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.utils import download_model
 
 
 class Word(NamedTuple):
@@ -57,7 +58,7 @@ class TranscriptionOptions(NamedTuple):
 class WhisperModel:
     def __init__(
         self,
-        model_path: str,
+        model_size_or_path: str,
         device: str = "auto",
         device_index: Union[int, List[int]] = 0,
         compute_type: str = "default",
@@ -67,7 +68,9 @@ class WhisperModel:
         """Initializes the Whisper model.
 
         Args:
-          model_path: Path to the converted model.
+          model_size_or_path: Size of the model to use (e.g. "large-v2", "small", "tiny.en", etc.)
+            or a path to a converted model directory. When a size is configured, the converted
+            model is downloaded from the Hugging Face Hub.
           device: Device to use for computation ("cpu", "cuda", "auto").
           device_index: Device ID to use.
             The model can also be loaded on multiple GPUs by passing a list of IDs
@@ -82,6 +85,11 @@ class WhisperModel:
             (concurrent calls to self.model.generate() will run in parallel).
             This can improve the global throughput at the cost of increased memory usage.
         """
+        if os.path.isdir(model_size_or_path):
+            model_path = model_size_or_path
+        else:
+            model_path = download_model(model_size_or_path)
+
         self.model = ctranslate2.models.Whisper(
             model_path,
             device=device,
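As a usage note on the new argument, either form works (a small sketch based on the docstring above; the local path is an example value):

```python
from faster_whisper import WhisperModel

# Load by size: the converted model is downloaded from the Hugging Face Hub.
model = WhisperModel("tiny", device="cpu", compute_type="int8")

# Or load from a local converted model directory (example path):
# model = WhisperModel("whisper-large-v2-ct2/", device="cuda", compute_type="float16")
```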
@@ -196,14 +204,16 @@ class WhisperModel:
         duration = audio.shape[0] / self.feature_extractor.sampling_rate
         features = self.feature_extractor(audio)
 
+        whisper_encoder = WhisperEncoder(self.model)
+
         if language is None:
             if not self.model.is_multilingual:
                 language = "en"
                 language_probability = 1
             else:
                 segment = features[:, : self.feature_extractor.nb_max_frames]
-                input = get_ctranslate2_storage(segment)
-                results = self.model.detect_language(input)
+                encoder_output = whisper_encoder(0, segment)
+                results = self.model.detect_language(encoder_output)
                 language_token, language_probability = results[0][0]
                 language = language_token[2:-2]
         else:
@@ -239,7 +249,7 @@ class WhisperModel:
             append_punctuations=append_punctuations,
         )
 
-        segments = self.generate_segments(features, tokenizer, options)
+        segments = self.generate_segments(features, whisper_encoder, tokenizer, options)
 
         audio_info = AudioInfo(
             language=language,
@@ -252,6 +262,7 @@ class WhisperModel:
     def generate_segments(
         self,
         features: np.ndarray,
+        whisper_encoder: "WhisperEncoder",
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Iterable[Segment]:
@@ -281,8 +292,10 @@ class WhisperModel:
                 prefix=options.prefix,
             )
 
+            encoder_output = whisper_encoder(seek, segment)
+
             result, avg_log_prob, temperature = self.generate_with_fallback(
-                segment, prompt, tokenizer, options
+                encoder_output, prompt, tokenizer, options
             )
 
             if options.no_speech_threshold is not None:
@@ -388,7 +401,7 @@ class WhisperModel:
                 self.add_word_timestamps(
                     current_segments,
                     tokenizer,
-                    segment,
+                    encoder_output,
                     segment_size,
                     options.prepend_punctuations,
                     options.append_punctuations,
@@ -428,12 +441,11 @@ class WhisperModel:
     def generate_with_fallback(
         self,
-        segment: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         prompt: List[int],
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
-        features = get_ctranslate2_storage(segment)
        result = None
        avg_log_prob = None
        final_temperature = None
@@ -458,7 +470,7 @@ class WhisperModel:
             final_temperature = temperature
             result = self.model.generate(
-                features,
+                encoder_output,
                 [prompt],
                 length_penalty=options.length_penalty,
                 max_length=self.max_length,
@@ -529,7 +541,7 @@ class WhisperModel:
         self,
         segments: List[dict],
         tokenizer: Tokenizer,
-        mel: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         num_frames: int,
         prepend_punctuations: str,
         append_punctuations: str,
@@ -543,7 +555,9 @@ class WhisperModel:
         ]
 
         text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
-        alignment = self.find_alignment(tokenizer, text_tokens, mel, num_frames)
+        alignment = self.find_alignment(
+            tokenizer, text_tokens, encoder_output, num_frames
+        )
         merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
         time_offset = (
@@ -585,7 +599,7 @@ class WhisperModel:
         self,
         tokenizer: Tokenizer,
         text_tokens: List[int],
-        mel: np.ndarray,
+        encoder_output: ctranslate2.StorageView,
         num_frames: int,
         median_filter_width: int = 7,
     ) -> List[dict]:
@@ -593,7 +607,7 @@ class WhisperModel:
             return []
 
         result = self.model.align(
-            get_ctranslate2_storage(mel),
+            encoder_output,
             tokenizer.sot_sequence,
             [text_tokens],
             num_frames,
@@ -646,9 +660,39 @@ class WhisperModel:
     ]
 
 
+class WhisperEncoder:
+    """Helper class to cache and reuse the encoder output."""
+
+    def __init__(self, model: ctranslate2.models.Whisper):
+        self.model = model
+
+        # When the model is running on multiple GPUs, the encoder output should be moved
+        # to the CPU since we don't know which GPU will handle the next job.
+        self.cache_on_cpu = len(model.device_index) > 1
+
+        self.last_seek = -1
+        self.last_output = None
+
+    def __call__(self, seek: int, features: np.ndarray) -> ctranslate2.StorageView:
+        if self.last_seek == seek:
+            return self.last_output
+
+        features = np.expand_dims(features, 0)
+        features = get_ctranslate2_storage(features)
+
+        output = self.model.encode(features, to_cpu=self.cache_on_cpu)
+
+        if self.last_output is not None:
+            del self.last_output
+
+        self.last_seek = seek
+        self.last_output = output
+
+        return output
+
+
 def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
     segment = np.ascontiguousarray(segment)
-    segment = np.expand_dims(segment, 0)
     segment = ctranslate2.StorageView.from_array(segment)
     return segment
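The `WhisperEncoder` helper memoizes the most recent encoder output keyed by `seek`, so consecutive steps that read the same audio window (for example, language detection and the first decoding pass, which both use `seek == 0`) run the encoder only once. The following self-contained sketch illustrates the same memoization pattern with a stub encoder; `CountingEncoder` and `CachedEncoder` are hypothetical names used only for this illustration:

```python
import numpy as np


class CountingEncoder:
    """Stub standing in for ctranslate2.models.Whisper; counts encode() calls."""

    def __init__(self):
        self.encode_calls = 0

    def encode(self, features):
        self.encode_calls += 1
        return features


class CachedEncoder:
    """Seek-keyed memoization, mirroring the WhisperEncoder helper above."""

    def __init__(self, model):
        self.model = model
        self.last_seek = -1
        self.last_output = None

    def __call__(self, seek, features):
        if self.last_seek == seek:
            return self.last_output  # cache hit: reuse the previous encoder output
        output = self.model.encode(np.expand_dims(features, 0))
        self.last_seek = seek
        self.last_output = output
        return output


model = CountingEncoder()
encoder = CachedEncoder(model)
features = np.zeros((80, 3000), dtype=np.float32)

encoder(0, features)  # first call runs the encoder (e.g. for language detection)
encoder(0, features)  # same seek: served from the cache
assert model.encode_calls == 1
```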
faster_whisper/utils.py

@@ -1,3 +1,42 @@
+from typing import Optional
+
+import huggingface_hub
+
+from tqdm.auto import tqdm
+
+
+def download_model(
+    size: str,
+    output_dir: Optional[str] = None,
+    show_progress_bars: bool = True,
+):
+    """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
+
+    The model is downloaded from https://huggingface.co/guillaumekln.
+
+    Args:
+      size: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en,
+        medium, medium.en, or large-v2).
+      output_dir: Directory where the model should be saved. If not set, the model is saved in
+        the standard Hugging Face cache directory.
+      show_progress_bars: Show the tqdm progress bars during the download.
+
+    Returns:
+      The path to the downloaded model.
+    """
+    repo_id = "guillaumekln/faster-whisper-%s" % size
+    kwargs = {}
+
+    if output_dir is not None:
+        kwargs["local_dir"] = output_dir
+        kwargs["local_dir_use_symlinks"] = False
+
+    if not show_progress_bars:
+        kwargs["tqdm_class"] = disabled_tqdm
+
+    return huggingface_hub.snapshot_download(repo_id, **kwargs)
+
+
 def format_timestamp(
     seconds: float,
     always_include_hours: bool = False,
@@ -19,3 +58,9 @@ def format_timestamp(
     return (
         f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
     )
+
+
+class disabled_tqdm(tqdm):
+    def __init__(self, *args, **kwargs):
+        kwargs["disable"] = True
+        super().__init__(*args, **kwargs)
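Together with the `__init__.py` change above, `download_model` becomes part of the public API. A minimal usage sketch (the output directory is an example value; the behavior follows the docstring and the `tests/test_utils.py` test below):

```python
from faster_whisper import download_model

# Fetch the converted "tiny" model from the Hugging Face Hub into a local
# directory, without tqdm progress bars.
model_dir = download_model("tiny", output_dir="whisper-tiny-ct2", show_progress_bars=False)

print(model_dir)  # prints "whisper-tiny-ct2"
```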
requirements.txt

@@ -1,3 +1,4 @@
 av==10.*
-ctranslate2>=3.9,<4
+ctranslate2>=3.10,<4
+huggingface_hub>=0.13
 tokenizers==0.13.*
setup.py: 5 changed lines

@@ -23,7 +23,7 @@ conversion_requires = get_requirements(
 
 setup(
     name="faster-whisper",
-    version="0.2.0",
+    version="0.3.0",
     license="MIT",
     description="Faster Whisper transcription with CTranslate2",
     long_description=get_long_description(),
@@ -48,8 +48,7 @@ setup(
     install_requires=install_requires,
     extras_require={
         "conversion": conversion_requires,
-        "dev": conversion_requires
-        + [
+        "dev": [
             "black==23.*",
             "flake8==6.*",
             "isort==5.*",
tests/conftest.py

@@ -1,6 +1,5 @@
 import os
 
-import ctranslate2
 import pytest
 
 
@@ -12,20 +11,3 @@ def data_dir():
 @pytest.fixture
 def jfk_path(data_dir):
     return os.path.join(data_dir, "jfk.flac")
-
-
-@pytest.fixture(scope="session")
-def tiny_model_dir(tmp_path_factory):
-    model_path = str(tmp_path_factory.mktemp("data") / "model")
-    convert_model("tiny", model_path)
-    return model_path
-
-
-def convert_model(size, output_dir):
-    name = "openai/whisper-%s" % size
-
-    ctranslate2.converters.TransformersConverter(
-        name,
-        copy_files=["tokenizer.json"],
-        load_as_float16=True,
-    ).convert(output_dir, quantization="float16")
tests/test_transcribe.py

@@ -1,8 +1,8 @@
 from faster_whisper import WhisperModel
 
 
-def test_transcribe(tiny_model_dir, jfk_path):
-    model = WhisperModel(tiny_model_dir)
+def test_transcribe(jfk_path):
+    model = WhisperModel("tiny")
 
     segments, info = model.transcribe(jfk_path, word_timestamps=True)
 
     assert info.language == "en"
tests/test_utils.py (new file): 17 lines

@@ -0,0 +1,17 @@
+import os
+
+from faster_whisper import download_model
+
+
+def test_download_model(tmpdir):
+    output_dir = str(tmpdir.join("model"))
+
+    model_dir = download_model("tiny", output_dir=output_dir)
+
+    assert model_dir == output_dir
+    assert os.path.isdir(model_dir)
+    assert not os.path.islink(model_dir)
+
+    for filename in os.listdir(model_dir):
+        path = os.path.join(model_dir, filename)
+        assert not os.path.islink(path)