mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
whisper: support batch inference, add librispeech WER test (#2074)
* whisper: support batch inference, add librispeech WER test, add kv caching and JIT * remove JIT_SUPPORTED_DEVICE --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
83
test/external/external_test_whisper_librispeech.py
vendored
Normal file
83
test/external/external_test_whisper_librispeech.py
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
import unittest
|
||||
import torch
|
||||
import tqdm
|
||||
import torchaudio
|
||||
import pathlib
|
||||
import jiwer
|
||||
import os
|
||||
import numpy as np
|
||||
from whisper.normalizers import EnglishTextNormalizer
|
||||
from examples.whisper import init_whisper, transcribe_waveform
|
||||
|
||||
class TestWhisperLibriSpeech(unittest.TestCase):
|
||||
# reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
|
||||
# the values should be consistent with the paper D.1.1 https://cdn.openai.com/papers/whisper.pdf#page=22
|
||||
# tinygrad WERs do not perfectly match due to what seem to be precision differences vs torch
|
||||
def test_en_tiny(self):
|
||||
run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749)
|
||||
|
||||
def test_tiny(self):
|
||||
run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187)
|
||||
|
||||
def test_en_base(self):
|
||||
run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505)
|
||||
|
||||
def test_en_small(self):
|
||||
run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228)
|
||||
|
||||
def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
|
||||
dataset = LibriSpeech()
|
||||
batch_size=16
|
||||
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
|
||||
|
||||
model, enc = init_whisper(model_name, batch_size=batch_size)
|
||||
|
||||
hypotheses = []
|
||||
references = []
|
||||
|
||||
for audio, texts in tqdm.tqdm(loader):
|
||||
transcriptions = transcribe_waveform(model, enc, audio.numpy())
|
||||
hypotheses.extend(transcriptions)
|
||||
references.extend(texts)
|
||||
|
||||
normalizer = EnglishTextNormalizer()
|
||||
normalized_hypotheses = [normalizer(text) for text in hypotheses]
|
||||
normalized_references = [normalizer(text) for text in references]
|
||||
wer = jiwer.wer(normalized_hypotheses, normalized_references)
|
||||
|
||||
np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
|
||||
print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
|
||||
del model, enc
|
||||
|
||||
class LibriSpeech(torch.utils.data.Dataset):
|
||||
def __init__(self):
|
||||
dir = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
|
||||
if not os.path.exists(dir):
|
||||
os.makedirs(dir)
|
||||
|
||||
self.dataset = torchaudio.datasets.LIBRISPEECH(
|
||||
root=dir,
|
||||
url="test-clean",
|
||||
download=True,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, item):
|
||||
audio, sample_rate, text, _, _, _ = self.dataset[item]
|
||||
assert sample_rate == 16000
|
||||
return pad_or_trim_tensor(audio[0]), text
|
||||
|
||||
def pad_or_trim_tensor(tensor, target_len=480000):
|
||||
curr_len = len(tensor)
|
||||
if curr_len == target_len:
|
||||
return tensor
|
||||
elif curr_len < target_len:
|
||||
return torch.cat((tensor, torch.zeros(target_len - curr_len)))
|
||||
else:
|
||||
return tensor[:target_len]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user