mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
Whisper less flaky tests (#13435)
* use less flaky metric for whisper long transcription * multiline long transcription 3 reference * fix reference transcript see https://homepage.ntu.edu.tw/~karchung/miniconversations/MC.htm sanitized for whisper * try lower wer threshold * add test for wer metric * extract TRANSCRIPTION_3_ALT * rename test * rename * add tests for high WER difference * move tests * sync metric
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
import pathlib
|
||||
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
|
||||
import examples.mlperf.metrics as metrics
|
||||
from tinygrad.helpers import CI, fetch, CPU_LLVM
|
||||
from tinygrad import Device, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
@@ -14,7 +15,39 @@ TEST_FILE_2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
|
||||
TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transcriptions of varying length."
|
||||
# TODO this file will possibly not survive long. find another 1-2 minute sound file online to transcribe
|
||||
TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3'
|
||||
TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." # noqa: E501
|
||||
TRANSCRIPTION_3 = """Just lie back and relax.
|
||||
Is the level of pressure about right?
|
||||
Yes, it's fine. And I'd like conditioner, please.
|
||||
Sure. I'm going to start the second lathering now.
|
||||
Would you like some Q-tips?
|
||||
How'd you like it cut?
|
||||
I'd like my bangs and the back trimmed,
|
||||
and I'd like the rest thinned out a bit and layered.
|
||||
Where would you like the part?
|
||||
On the left, right about here.
|
||||
Here, have a look. What do you think?
|
||||
It's fine. Here's thousand NT dollars.
|
||||
It's 30 NT extra for the rinse. Here's your change and receipt.
|
||||
Thank you, and please come again!
|
||||
So, how do you like it?
|
||||
It could have been worse. But you'll notice that I didn't ask her for her card.
|
||||
Hmm, yeah.
|
||||
Mm, maybe you can try that place over there next time."""
|
||||
|
||||
TRANSCRIPTION_3_ALT = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine. And I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs on the back trimmed, and I'd like the rest to stand out a bit and layered. Where would you like the part? On the left, right about here. Here. Have a look. What do you think? It's fine. Here's a thousand and eighty dollars. It's thirty and t extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." #noqa: E501
|
||||
# NOTE: same as TRANSCRIPTION_3 but with minor changes that should only amount to ~0.079 WER difference (see test_wer_same)
|
||||
# 'and' --> 'on'
|
||||
# 'thinned' --> 'to stand'
|
||||
# 'nt' --> 'and eighty'
|
||||
# '30 nt' --> 'thirty and t'
|
||||
# 'rinse' --> 'rants'
|
||||
# 'mm' --> ''
|
||||
|
||||
def wer_helper(result: str, reference: str)->float:
|
||||
result = metrics.normalize_string(result)
|
||||
reference = metrics.normalize_string(reference)
|
||||
wer, _, _ = metrics.word_error_rate([result], [reference])
|
||||
return wer
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in ["CPU"], "slow")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
|
||||
@@ -30,6 +63,15 @@ class TestWhisper(unittest.TestCase):
|
||||
del cls.model
|
||||
del cls.enc
|
||||
|
||||
def assertWER(self, actual: str, expected: str, threshold: float):
|
||||
__tracebackhide__ = True # Hide traceback for py.test
|
||||
wer = wer_helper(actual, expected)
|
||||
if wer > threshold:
|
||||
err = f"WER={wer:.3f} > {threshold}"
|
||||
raise AssertionError(
|
||||
err
|
||||
)
|
||||
|
||||
def test_transcribe_file1(self):
|
||||
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
|
||||
|
||||
@@ -56,7 +98,7 @@ class TestWhisper(unittest.TestCase):
|
||||
def test_transcribe_long(self):
|
||||
waveform = [load_file_waveform(fetch(TEST_FILE_3_URL))]
|
||||
transcription = transcribe_waveform(self.model, self.enc, waveform)
|
||||
self.assertEqual(TRANSCRIPTION_3, transcription)
|
||||
self.assertWER(transcription, TRANSCRIPTION_3, 0.085)
|
||||
|
||||
@unittest.skipIf(CI or (Device.DEFAULT == "CPU" and CPU_LLVM), "too long for CI")
|
||||
def test_transcribe_long_no_batch(self):
|
||||
@@ -64,8 +106,24 @@ class TestWhisper(unittest.TestCase):
|
||||
|
||||
trancriptions = transcribe_waveform(self.model, self.enc, waveforms)
|
||||
self.assertEqual(2, len(trancriptions))
|
||||
self.assertEqual(TRANSCRIPTION_3, trancriptions[0])
|
||||
self.assertWER(trancriptions[0], TRANSCRIPTION_3, 0.085)
|
||||
self.assertEqual(TRANSCRIPTION_1, trancriptions[1])
|
||||
|
||||
def test_wer_same(self):
|
||||
reference = TRANSCRIPTION_3
|
||||
self.assertWER(TRANSCRIPTION_3_ALT, reference, 0.079)
|
||||
|
||||
def test_wer_different(self):
|
||||
reference = TRANSCRIPTION_3
|
||||
self.assertWER("[no speech]", reference, 1.0)
|
||||
|
||||
def test_wer_different_2(self):
|
||||
reference = TRANSCRIPTION_3
|
||||
self.assertWER("", reference, 1.0)
|
||||
|
||||
def test_wer_different_3(self):
|
||||
reference = TRANSCRIPTION_3
|
||||
self.assertWER(reference[:len(reference)//2], reference, 0.524)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user