use jiwer instead of evaluate in benchmarks (#1159)

Mahmoud Ashraf
2024-11-20 22:51:55 +02:00
committed by GitHub
parent 491852e1b9
commit 9c8ef76c98
3 changed files with 23 additions and 32 deletions
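
Note: the core change is swapping the evaluate wrapper for a direct jiwer call. A minimal before/after sketch of the metric API (the sample strings are invented for illustration; evaluate's "wer" metric is itself backed by jiwer, so the two agree up to jiwer's default text transforms):

# before: WER through the evaluate wrapper
from evaluate import load

wer_metric = load("wer")
score = wer_metric.compute(
    predictions=["the quick brown dog"], references=["the quick brown fox"]
)

# after: call jiwer directly and drop the evaluate dependency
from jiwer import wer

score = wer(reference=["the quick brown fox"], hypothesis=["the quick brown dog"])
print(score)  # 0.25: one substitution out of four reference words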

View File

@@ -5,9 +5,9 @@ import os
 from io import BytesIO
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from pytubefix import YouTube
-from torch.utils.data import DataLoader
+from pytubefix.exceptions import VideoUnavailable
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
@@ -17,15 +17,19 @@ from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
 def url_to_audio(row):
     buffer = BytesIO()
     yt = YouTube(row["link"])
-    video = (
-        yt.streams.filter(only_audio=True, mime_type="audio/mp4")
-        .order_by("bitrate")
-        .desc()
-        .first()
-    )
-    video.stream_to_buffer(buffer)
-    buffer.seek(0)
-    row["audio"] = decode_audio(buffer)
+    try:
+        video = (
+            yt.streams.filter(only_audio=True, mime_type="audio/mp4")
+            .order_by("bitrate")
+            .desc()
+            .last()
+        )
+        video.stream_to_buffer(buffer)
+        buffer.seek(0)
+        row["audio"] = decode_audio(buffer)
+    except VideoUnavailable:
+        print(f'Failed to download: {row["link"]}')
+        row["audio"] = []
     return row
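
The try/except above turns an unavailable video into an empty row instead of crashing the whole benchmark; rows with empty audio are skipped later in the evaluation loop. A standalone sketch of that pattern under the same pytubefix API (the fetch_audio helper name is mine; note the stream selection also moves from .first() to .last() on the descending bitrate ordering, i.e. the smallest matching audio stream):

from io import BytesIO
from typing import Optional

from pytubefix import YouTube
from pytubefix.exceptions import VideoUnavailable


def fetch_audio(url: str) -> Optional[BytesIO]:
    # download the lowest-bitrate audio/mp4 stream, or return None when
    # the video cannot be downloaded
    buffer = BytesIO()
    try:
        stream = (
            YouTube(url)
            .streams.filter(only_audio=True, mime_type="audio/mp4")
            .order_by("bitrate")
            .desc()
            .last()
        )
        stream.stream_to_buffer(buffer)
        buffer.seek(0)
        return buffer
    except VideoUnavailable:
        print(f"Failed to download: {url}")
        return None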
@@ -39,19 +43,12 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))
 
 dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map(
     url_to_audio
 )
 
-dataset = iter(
-    DataLoader(dataset["test"], batch_size=1, prefetch_factor=4, num_workers=2)
-)
-
 model = WhisperModel("large-v3", device="cuda")
 pipeline = BatchedInferencePipeline(model, device="cuda")
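
Dropping the torch DataLoader works because load_dataset(..., streaming=True) returns an iterable dataset whose splits can be consumed directly, with .map applied lazily per example. A quick sketch of the iteration pattern now used (assuming the dataset named in the script above is reachable):

from datasets import load_dataset

# streaming=True avoids downloading the whole dataset up front and yields
# examples one by one; no DataLoader wrapper is needed
dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True)

for row in dataset["test"]:
    print(sorted(row.keys()))  # inspect one lazily fetched example
    break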
@@ -59,7 +56,9 @@ pipeline = BatchedInferencePipeline(model, device="cuda")
 all_transcriptions = []
 all_references = []
 # iterate over the dataset and run inference
-for i, row in tqdm(enumerate(dataset), desc="Evaluating..."):
+for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."):
+    if not row["audio"]:
+        continue
     result, info = pipeline.transcribe(
         row["audio"][0],
         batch_size=8,
@@ -77,7 +76,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)

View File

@@ -1,6 +1,5 @@
 transformers
 jiwer
-evaluate
 datasets
 memory_profiler
 py3nvml

View File

@@ -3,7 +3,7 @@ import json
 import os
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
@@ -25,9 +25,6 @@ model = WhisperModel(model_path, device="cuda")
 # load the dataset with streaming mode
 dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))
@@ -58,7 +55,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)