Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00.
* move more tests to test/null, split some existing ones
* null work
* null work
* move more
* fixes
* move PIL
* PIL in CLIP
* don't move that
30 lines · 1.3 KiB · Python
import unittest
|
|
from unittest.mock import patch
|
|
from tinygrad import Tensor
|
|
|
|
class TestTransformerGenerate(unittest.TestCase):
  def test_start_pos_parameter_is_used(self):
    """Regression test: generate() must honor start_pos rather than always resetting it to 0."""
    from tinygrad.apps.llm import Transformer

    # Smallest configuration that lets us construct a Transformer at all.
    model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
                        norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)

    calls = []
    def fake_forward(_model, tokens, start_pos):
      # Record the token tensor's shape and the start position (unwrap a bound
      # symbolic variable to its concrete value when it isn't a plain int).
      pos = start_pos if isinstance(start_pos, int) else start_pos.bind_val
      calls.append((tokens.shape, pos))
      return Tensor([[42]])  # fake next-token prediction

    with patch.object(Transformer, '__call__', fake_forward):
      gen = model.generate([1, 2, 3, 4, 5], start_pos=3)
      next(gen)  # drive a single generation step

      # With start_pos=3 only tokens[3:] = [4, 5] should be fed in (length 2);
      # the old bug (start_pos forced to 0) would have passed all 5 tokens.
      first_shape, first_pos = calls[0]
      self.assertEqual(first_shape[-1], 2)  # shape should be (1, 2)
      self.assertEqual(first_pos, 3)        # start_pos should be 3, not 0
|
|
# Allow running this test file directly (e.g. `python test_llm_generate.py`).
if __name__ == '__main__':
  unittest.main()