mirror of https://github.com/SYSTRAN/faster-whisper.git
use jiwer instead of evaluate in benchmarks (#1159)
@@ -5,9 +5,9 @@ import os
 from io import BytesIO
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from pytubefix import YouTube
-from torch.utils.data import DataLoader
+from pytubefix.exceptions import VideoUnavailable
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
 
@@ -17,15 +17,19 @@ from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
 def url_to_audio(row):
     buffer = BytesIO()
     yt = YouTube(row["link"])
-    video = (
-        yt.streams.filter(only_audio=True, mime_type="audio/mp4")
-        .order_by("bitrate")
-        .desc()
-        .first()
-    )
-    video.stream_to_buffer(buffer)
-    buffer.seek(0)
-    row["audio"] = decode_audio(buffer)
+    try:
+        video = (
+            yt.streams.filter(only_audio=True, mime_type="audio/mp4")
+            .order_by("bitrate")
+            .desc()
+            .last()
+        )
+        video.stream_to_buffer(buffer)
+        buffer.seek(0)
+        row["audio"] = decode_audio(buffer)
+    except VideoUnavailable:
+        print(f'Failed to download: {row["link"]}')
+        row["audio"] = []
     return row
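For context on the hunk above: pytubefix's order_by("bitrate") sorts matching streams ascending and .desc() reverses that order, so .last() selects the lowest-bitrate audio stream, and VideoUnavailable lets unavailable links be skipped instead of aborting the run. A minimal standalone sketch of the same download-with-fallback pattern, assuming pytubefix is installed; the URL is an arbitrary placeholder, not taken from the benchmark dataset:

# Sketch of the download pattern used above; the URL is a placeholder.
from io import BytesIO

from pytubefix import YouTube
from pytubefix.exceptions import VideoUnavailable

buffer = BytesIO()
try:
    yt = YouTube("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
    # order_by("bitrate") sorts ascending; .desc() reverses it, so
    # .last() returns the stream with the lowest bitrate.
    stream = (
        yt.streams.filter(only_audio=True, mime_type="audio/mp4")
        .order_by("bitrate")
        .desc()
        .last()
    )
    stream.stream_to_buffer(buffer)
    buffer.seek(0)
except VideoUnavailable:
    print("video unavailable, skipping")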
@@ -39,19 +43,12 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))
 
 dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map(
     url_to_audio
 )
-dataset = iter(
-    DataLoader(dataset["test"], batch_size=1, prefetch_factor=4, num_workers=2)
-)
-
 model = WhisperModel("large-v3", device="cuda")
 pipeline = BatchedInferencePipeline(model, device="cuda")
 
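A likely reading of this hunk: with streaming=True, load_dataset returns streaming splits that can be iterated directly (as the next hunk does with dataset["test"]), so the torch DataLoader wrapper, and with it the torch import, becomes unnecessary. A hedged sketch of direct iteration over a streaming split, reusing the librispeech_asr dataset that appears later in this commit; the key inspection is illustrative only:

# Sketch: iterate a streaming Hugging Face dataset split directly,
# with no torch DataLoader in between.
from datasets import load_dataset

dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
for i, row in enumerate(dataset):
    if i >= 2:  # peek at the first two rows only
        break
    print(sorted(row.keys()))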
@@ -59,7 +56,9 @@ pipeline = BatchedInferencePipeline(model, device="cuda")
 all_transcriptions = []
 all_references = []
 # iterate over the dataset and run inference
-for i, row in tqdm(enumerate(dataset), desc="Evaluating..."):
+for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."):
+    if not row["audio"]:
+        continue
     result, info = pipeline.transcribe(
         row["audio"][0],
         batch_size=8,
@@ -77,7 +76,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)
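The replacement uses jiwer's functional API: jiwer.wer accepts the reference and hypothesis strings (or lists of strings) as keyword arguments, exactly as the diff calls it, and returns the error rate as a fraction, hence the multiplication by 100. A minimal sketch with made-up strings:

# Corpus-level WER with jiwer; the example strings are made up.
from jiwer import wer

references = ["the quick brown fox", "it is raining today"]
hypotheses = ["the quick brown fox", "it was raining today"]

word_error_rate = 100 * wer(reference=references, hypothesis=hypotheses)
print("WER: %.3f" % word_error_rate)  # one substitution in eight words: 12.500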
@@ -1,6 +1,5 @@
 transformers
-evaluate
+jiwer
 datasets
 memory_profiler
 py3nvml
@@ -3,7 +3,7 @@ import json
 import os
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
 
@@ -25,9 +25,6 @@ model = WhisperModel(model_path, device="cuda")
 # load the dataset with streaming mode
 dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))
 
@@ -58,7 +55,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)