Mirror of https://github.com/SYSTRAN/faster-whisper.git (synced 2026-01-08 13:14:00 -05:00)
use jiwer instead of evaluate in benchmarks (#1159)
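The change replaces Hugging Face's evaluate wrapper (whose "wer" metric itself calls jiwer) with a direct jiwer.wer call in both benchmark scripts. A minimal sketch of the before/after API, with toy lists standing in for the benchmarks' normalized transcripts:

from jiwer import wer

# Toy data, not from the benchmarks.
all_references = ["the cat sat on the mat"]
all_transcriptions = ["the cat sit on the mat"]

# Before this commit the scripts went through the evaluate wrapper:
#   from evaluate import load
#   wer_metric = load("wer")
#   score = 100 * wer_metric.compute(
#       predictions=all_transcriptions, references=all_references
#   )

# After: call jiwer directly; it returns a fraction in [0, 1], hence the * 100.
score = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
print("WER: %.3f" % score)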
In the YouTube-Commons evaluation script, the imports swap evaluate for jiwer, drop the torch DataLoader import (its only use is removed below), and pull in pytubefix's VideoUnavailable exception:

@@ -5,9 +5,9 @@ import os
 from io import BytesIO
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from pytubefix import YouTube
-from torch.utils.data import DataLoader
+from pytubefix.exceptions import VideoUnavailable
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
 
url_to_audio gains a try/except so that an unavailable video leaves row["audio"] empty instead of raising, and the stream selection switches from .first() to .last() on the bitrate-sorted list:

@@ -17,15 +17,19 @@ from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
 def url_to_audio(row):
     buffer = BytesIO()
     yt = YouTube(row["link"])
-    video = (
-        yt.streams.filter(only_audio=True, mime_type="audio/mp4")
-        .order_by("bitrate")
-        .desc()
-        .first()
-    )
-    video.stream_to_buffer(buffer)
-    buffer.seek(0)
-    row["audio"] = decode_audio(buffer)
+    try:
+        video = (
+            yt.streams.filter(only_audio=True, mime_type="audio/mp4")
+            .order_by("bitrate")
+            .desc()
+            .last()
+        )
+        video.stream_to_buffer(buffer)
+        buffer.seek(0)
+        row["audio"] = decode_audio(buffer)
+    except VideoUnavailable:
+        print(f'Failed to download: {row["link"]}')
+        row["audio"] = []
     return row
 
 
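The net effect on a single link can be sketched as a small standalone helper; this is illustrative only (fetch_audio_bytes is a hypothetical name, not part of the benchmark), assuming pytubefix is installed:

from io import BytesIO
from typing import Optional

from pytubefix import YouTube
from pytubefix.exceptions import VideoUnavailable


def fetch_audio_bytes(link: str) -> Optional[BytesIO]:
    """Mirror the new logic for one link: a buffer on success, None if unavailable."""
    buffer = BytesIO()
    try:
        yt = YouTube(link)
        stream = (
            yt.streams.filter(only_audio=True, mime_type="audio/mp4")
            .order_by("bitrate")  # ascending by bitrate
            .desc()               # reversed, so highest bitrate first
            .last()               # the commit now takes the last entry
        )
        stream.stream_to_buffer(buffer)
        buffer.seek(0)
        return buffer
    except VideoUnavailable:
        return None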
Further down, the evaluate metric object and the DataLoader wrapper around the streaming dataset are removed:

@@ -39,19 +43,12 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))
 
 dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map(
     url_to_audio
 )
-dataset = iter(
-    DataLoader(dataset["test"], batch_size=1, prefetch_factor=4, num_workers=2)
-)
-
 model = WhisperModel("large-v3", device="cuda")
 pipeline = BatchedInferencePipeline(model, device="cuda")
 
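With the DataLoader/iter wrapper gone, the streaming test split is consumed lazily as plain dicts (the loop in the next hunk iterates it directly). A minimal sketch of that pattern, assuming network access to the Hub; the peek limit is arbitrary:

from datasets import load_dataset

dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True)

# dataset["test"] is an IterableDataset: rows stream in one at a time as dicts.
for i, row in enumerate(dataset["test"]):
    print(i, row["link"])
    if i >= 2:  # peek at a few rows only
        break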
The inference loop now iterates the streaming test split directly and skips rows whose download failed:

@@ -59,7 +56,9 @@ pipeline = BatchedInferencePipeline(model, device="cuda")
 all_transcriptions = []
 all_references = []
 # iterate over the dataset and run inference
-for i, row in tqdm(enumerate(dataset), desc="Evaluating..."):
+for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."):
+    if not row["audio"]:
+        continue
     result, info = pipeline.transcribe(
         row["audio"][0],
         batch_size=8,
Finally, the WER computation calls jiwer.wer directly:

@@ -77,7 +76,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)
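For intuition on the number the new code prints, a toy computation with invented strings (not benchmark data):

from jiwer import wer

# 8 reference words; the hypothesis has 1 substitution ("jumps" -> "jumped")
# and 1 deletion ("the"), so WER = 2/8 = 0.25 and this prints 25.000.
refs = ["the quick brown fox jumps over the dog"]
hyps = ["the quick brown fox jumped over dog"]
word_error_rate = 100 * wer(hypothesis=hyps, reference=refs)
print("WER: %.3f" % word_error_rate)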
The benchmark requirements drop evaluate (jiwer was already listed):

@@ -1,6 +1,5 @@
 transformers
 jiwer
-evaluate
 datasets
 memory_profiler
 py3nvml
The LibriSpeech WER benchmark script gets the same treatment. Its imports:

@@ -3,7 +3,7 @@ import json
 import os
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
 
The evaluate metric object is removed:

@@ -25,9 +25,6 @@ model = WhisperModel(model_path, device="cuda")
 # load the dataset with streaming mode
 dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))
 
And the final WER computation changes in the same way as in the YouTube-Commons script:

@@ -58,7 +55,5 @@ all_transcriptions = [normalizer(transcription) for transcription in all_transcr
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)