mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
@@ -94,7 +94,7 @@ class AudioEncoder:
|
||||
class TextDecoder:
|
||||
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
|
||||
self.max_tokens_to_sample = n_text_ctx // 2
|
||||
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample
|
||||
self.max_self_attn_cache_len = n_text_ctx
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
|
||||
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
|
||||
@@ -104,7 +104,7 @@ class TextDecoder:
|
||||
self.getjitted = collections.defaultdict(lambda: TinyJit(self.forward))
|
||||
|
||||
def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor):
|
||||
pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos) if pos else 0
|
||||
pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len-1).bind(pos) if pos else 0
|
||||
return self.getjitted[x.shape](x, pos, encoded_audio)
|
||||
|
||||
def forward(self, x:Tensor, pos:Union[Variable, Literal[0]], encoded_audio:Tensor):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
import pathlib
|
||||
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
|
||||
from tinygrad.helpers import CI, fetch, Context
|
||||
from tinygrad.helpers import CI, fetch
|
||||
from tinygrad import Device, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@@ -24,25 +24,15 @@ class TestWhisper(unittest.TestCase):
|
||||
model, enc = init_whisper("tiny.en", batch_size=2)
|
||||
cls.model = model
|
||||
cls.enc = enc
|
||||
# TODO: whisper has out of bounds access somewhere
|
||||
cls.context = Context(IGNORE_OOB=1)
|
||||
cls.context.__enter__()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.context.__exit__(None, None, None)
|
||||
del cls.model
|
||||
del cls.enc
|
||||
|
||||
def test_transcribe_file1(self):
|
||||
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
|
||||
|
||||
@unittest.expectedFailure # Test for out of bounds access
|
||||
@unittest.skip("TODO: flaky")
|
||||
def test_transcribe_file1_OOB(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
|
||||
|
||||
@unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too many tests for CI")
|
||||
def test_transcribe_file2(self):
|
||||
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2)
|
||||
|
||||
Reference in New Issue
Block a user