diff --git a/benchmark/evaluate_yt_commons.py b/benchmark/evaluate_yt_commons.py index 0511be6..cbdce4f 100644 --- a/benchmark/evaluate_yt_commons.py +++ b/benchmark/evaluate_yt_commons.py @@ -5,9 +5,9 @@ import os from io import BytesIO from datasets import load_dataset -from evaluate import load +from jiwer import wer from pytubefix import YouTube -from torch.utils.data import DataLoader +from pytubefix.exceptions import VideoUnavailable from tqdm import tqdm from transformers.models.whisper.english_normalizer import EnglishTextNormalizer @@ -17,15 +17,19 @@ from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio def url_to_audio(row): buffer = BytesIO() yt = YouTube(row["link"]) - video = ( - yt.streams.filter(only_audio=True, mime_type="audio/mp4") - .order_by("bitrate") - .desc() - .first() - ) - video.stream_to_buffer(buffer) - buffer.seek(0) - row["audio"] = decode_audio(buffer) + try: + video = ( + yt.streams.filter(only_audio=True, mime_type="audio/mp4") + .order_by("bitrate") + .desc() + .last() + ) + video.stream_to_buffer(buffer) + buffer.seek(0) + row["audio"] = decode_audio(buffer) + except VideoUnavailable: + print(f'Failed to download: {row["link"]}') + row["audio"] = [] return row @@ -39,19 +43,12 @@ parser.add_argument( ) args = parser.parse_args() -# define the evaluation metric -wer_metric = load("wer") - with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f: normalizer = EnglishTextNormalizer(json.load(f)) dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map( url_to_audio ) -dataset = iter( - DataLoader(dataset["test"], batch_size=1, prefetch_factor=4, num_workers=2) -) - model = WhisperModel("large-v3", device="cuda") pipeline = BatchedInferencePipeline(model, device="cuda") @@ -59,7 +56,9 @@ pipeline = BatchedInferencePipeline(model, device="cuda") all_transcriptions = [] all_references = [] # iterate over the dataset and run inference -for i, row in tqdm(enumerate(dataset), desc="Evaluating..."): +for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."): + if not row["audio"]: + continue result, info = pipeline.transcribe( row["audio"][0], batch_size=8, @@ -77,7 +76,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr all_references = [normalizer(reference) for reference in all_references] # compute the WER metric -wer = 100 * wer_metric.compute( - predictions=all_transcriptions, references=all_references -) -print("WER: %.3f" % wer) +word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references) +print("WER: %.3f" % word_error_rate) diff --git a/benchmark/requirements.benchmark.txt b/benchmark/requirements.benchmark.txt index c49dcca..674c23e 100644 --- a/benchmark/requirements.benchmark.txt +++ b/benchmark/requirements.benchmark.txt @@ -1,6 +1,5 @@ transformers jiwer -evaluate datasets memory_profiler py3nvml diff --git a/benchmark/wer_benchmark.py b/benchmark/wer_benchmark.py index f7a0b79..2bc1bfb 100644 --- a/benchmark/wer_benchmark.py +++ b/benchmark/wer_benchmark.py @@ -3,7 +3,7 @@ import json import os from datasets import load_dataset -from evaluate import load +from jiwer import wer from tqdm import tqdm from transformers.models.whisper.english_normalizer import EnglishTextNormalizer @@ -25,9 +25,6 @@ model = WhisperModel(model_path, device="cuda") # load the dataset with streaming mode dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True) -# define the evaluation metric -wer_metric = load("wer") - with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f: normalizer = EnglishTextNormalizer(json.load(f)) @@ -58,7 +55,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr all_references = [normalizer(reference) for reference in all_references] # compute the WER metric -wer = 100 * wer_metric.compute( - predictions=all_transcriptions, references=all_references -) -print("WER: %.3f" % wer) +word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references) +print("WER: %.3f" % word_error_rate)