use jiwer instead of evaluate in benchmarks (#1159)

This commit is contained in:
Mahmoud Ashraf
2024-11-20 22:51:55 +02:00
committed by GitHub
parent 491852e1b9
commit 9c8ef76c98
3 changed files with 23 additions and 32 deletions

View File

@@ -5,9 +5,9 @@ import os
from io import BytesIO
from datasets import load_dataset
from evaluate import load
from jiwer import wer
from pytubefix import YouTube
from torch.utils.data import DataLoader
from pytubefix.exceptions import VideoUnavailable
from tqdm import tqdm
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
@@ -17,15 +17,19 @@ from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
def url_to_audio(row):
buffer = BytesIO()
yt = YouTube(row["link"])
video = (
yt.streams.filter(only_audio=True, mime_type="audio/mp4")
.order_by("bitrate")
.desc()
.first()
)
video.stream_to_buffer(buffer)
buffer.seek(0)
row["audio"] = decode_audio(buffer)
try:
video = (
yt.streams.filter(only_audio=True, mime_type="audio/mp4")
.order_by("bitrate")
.desc()
.last()
)
video.stream_to_buffer(buffer)
buffer.seek(0)
row["audio"] = decode_audio(buffer)
except VideoUnavailable:
print(f'Failed to download: {row["link"]}')
row["audio"] = []
return row
@@ -39,19 +43,12 @@ parser.add_argument(
)
args = parser.parse_args()
# define the evaluation metric
wer_metric = load("wer")
with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
normalizer = EnglishTextNormalizer(json.load(f))
dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map(
url_to_audio
)
dataset = iter(
DataLoader(dataset["test"], batch_size=1, prefetch_factor=4, num_workers=2)
)
model = WhisperModel("large-v3", device="cuda")
pipeline = BatchedInferencePipeline(model, device="cuda")
@@ -59,7 +56,9 @@ pipeline = BatchedInferencePipeline(model, device="cuda")
all_transcriptions = []
all_references = []
# iterate over the dataset and run inference
for i, row in tqdm(enumerate(dataset), desc="Evaluating..."):
for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."):
if not row["audio"]:
continue
result, info = pipeline.transcribe(
row["audio"][0],
batch_size=8,
@@ -77,7 +76,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr
all_references = [normalizer(reference) for reference in all_references]
# compute the WER metric
wer = 100 * wer_metric.compute(
predictions=all_transcriptions, references=all_references
)
print("WER: %.3f" % wer)
word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
print("WER: %.3f" % word_error_rate)

View File

@@ -1,6 +1,5 @@
transformers
jiwer
evaluate
datasets
memory_profiler
py3nvml

View File

@@ -3,7 +3,7 @@ import json
import os
from datasets import load_dataset
from evaluate import load
from jiwer import wer
from tqdm import tqdm
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
@@ -25,9 +25,6 @@ model = WhisperModel(model_path, device="cuda")
# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
# define the evaluation metric
wer_metric = load("wer")
with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
normalizer = EnglishTextNormalizer(json.load(f))
@@ -58,7 +55,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr
all_references = [normalizer(reference) for reference in all_references]
# compute the WER metric
wer = 100 * wer_metric.compute(
predictions=all_transcriptions, references=all_references
)
print("WER: %.3f" % wer)
word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
print("WER: %.3f" % word_error_rate)