enable whisper batch for long sequences (#6458)

* long batch +test

* long batch +test

* cleanup

* rollback syntactic changes

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
kormann
2024-09-17 06:42:10 +02:00
committed by GitHub
parent 7c942418a1
commit f5dd25d376
2 changed files with 29 additions and 35 deletions

View File

@@ -61,8 +61,11 @@ class TestWhisper(unittest.TestCase):
@unittest.skipIf(CI, "too long for CI")
def test_transcribe_long_no_batch(self):
waveforms = [load_file_waveform(fetch(TEST_FILE_3_URL)), load_file_waveform(TEST_FILE_1)]
with self.assertRaises(Exception):
transcribe_waveform(self.model, self.enc, waveforms)
trancriptions = transcribe_waveform(self.model, self.enc, waveforms)
self.assertEqual(2, len(trancriptions))
self.assertEqual(TRANSCRIPTION_3, trancriptions[0])
self.assertEqual(TRANSCRIPTION_1, trancriptions[1])
if __name__ == '__main__':
unittest.main()