Fix whisper OOB (#10685)

* fix whisper and test

* remove import
This commit is contained in:
Sieds Lykles
2025-06-08 02:23:50 +02:00
committed by GitHub
parent 53ed64e133
commit c29a56dd51
2 changed files with 3 additions and 13 deletions

View File

@@ -94,7 +94,7 @@ class AudioEncoder:
class TextDecoder:
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
self.max_tokens_to_sample = n_text_ctx // 2
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample
self.max_self_attn_cache_len = n_text_ctx
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
@@ -104,7 +104,7 @@ class TextDecoder:
self.getjitted = collections.defaultdict(lambda: TinyJit(self.forward))
def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor):
pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos) if pos else 0
pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len-1).bind(pos) if pos else 0
return self.getjitted[x.shape](x, pos, encoded_audio)
def forward(self, x:Tensor, pos:Union[Variable, Literal[0]], encoded_audio:Tensor):

View File

@@ -1,7 +1,7 @@
import unittest
import pathlib
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
from tinygrad.helpers import CI, fetch, Context
from tinygrad.helpers import CI, fetch
from tinygrad import Device, dtypes
from tinygrad.device import is_dtype_supported
@@ -24,25 +24,15 @@ class TestWhisper(unittest.TestCase):
model, enc = init_whisper("tiny.en", batch_size=2)
cls.model = model
cls.enc = enc
# TODO: whisper has out of bounds access somewhere
cls.context = Context(IGNORE_OOB=1)
cls.context.__enter__()
@classmethod
def tearDownClass(cls):
cls.context.__exit__(None, None, None)
del cls.model
del cls.enc
def test_transcribe_file1(self):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
@unittest.expectedFailure # Test for out of bounds access
@unittest.skip("TODO: flaky")
def test_transcribe_file1_OOB(self):
with Context(IGNORE_OOB=0):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
@unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too many tests for CI")
def test_transcribe_file2(self):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2)