mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
* whisper: support batch inference, add librispeech WER test, add kv caching and JIT * remove JIT_SUPPORTED_DEVICE --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
37 lines
1.6 KiB
Python
37 lines
1.6 KiB
Python
import unittest
|
|
import pathlib
|
|
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
|
|
from tinygrad.ops import Device
|
|
|
|
@unittest.skipUnless(Device.DEFAULT == "METAL", "Some non-metal backends spend too long trying to allocate a 20GB array")
class TestWhisper(unittest.TestCase):
  """End-to-end transcription checks for the `tiny.en` Whisper example.

  One model/`enc` pair is created in setUpClass and shared by every test in
  the class. The whole class is skipped unless the default device is METAL.
  """

  # Exact transcripts expected for the WAV fixtures stored in ./whisper/.
  _SHORT_TEXT = "Could you please let me out of the box?"
  _LONG_TEXT = "a slightly longer audio file so that we can test batch transcriptions of varying length."

  @staticmethod
  def _fixture(relpath):
    """Return, as a str, `relpath` joined onto the directory containing this test file."""
    return str(pathlib.Path(__file__).parent / relpath)

  @classmethod
  def setUpClass(cls):
    # batch_size=2 lines up with the two clips fed together to test_transcribe_batch.
    cls.model, cls.enc = init_whisper("tiny.en", batch_size=2)

  @classmethod
  def tearDownClass(cls):
    # Drop the shared class-level references once every test has run.
    del cls.model, cls.enc

  def test_transcribe_file(self):
    """A single WAV file transcribes to exactly the spoken sentence."""
    # Audio generated with the command on MacOS:
    #   say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test
    # We use the WAVE type because it's easier to decode in CI test environments
    wav_path = self._fixture("whisper/test.wav")
    self.assertEqual(self._SHORT_TEXT, transcribe_file(self.model, self.enc, wav_path))

  def test_transcribe_batch(self):
    """Two clips of different lengths transcribe correctly in a single batched call."""
    clips = ("whisper/test.wav", "whisper/test2.wav")
    waveforms = [load_file_waveform(self._fixture(clip)) for clip in clips]

    transcriptions = transcribe_waveform(self.model, self.enc, waveforms)

    self.assertEqual(2, len(transcriptions))
    # Compare in input order; each result must correspond to the clip at the same index.
    for expected, actual in zip((self._SHORT_TEXT, self._LONG_TEXT), transcriptions):
      self.assertEqual(expected, actual)
|