Files
tinygrad/test/unit/test_llm_server.py
George Hotz d59e6e7a37 move more tests to test/null, split some existing ones (#14512)
* move more tests to test/null, split some existing ones

* null work

* null work

* move more

* fixes

* move PIL

* PIL in CLIP

* don't move that
2026-02-03 20:20:20 +08:00

30 lines
1.3 KiB
Python

import unittest
from unittest.mock import patch
from tinygrad import Tensor
class TestTransformerGenerate(unittest.TestCase):
  def test_start_pos_parameter_is_used(self):
    """Regression test: generate() must honor start_pos rather than resetting it to 0."""
    from tinygrad.apps.llm import Transformer
    # smallest configuration that constructs cheaply
    model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
                        norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
    seen = []
    def spy(self, tokens, start_pos):
      # start_pos may arrive as a plain int or a bound symbolic variable
      pos = start_pos if isinstance(start_pos, int) else start_pos.bind_val
      seen.append((tokens.shape, pos))
      return Tensor([[42]])  # fake "next token" prediction
    with patch.object(Transformer, '__call__', spy):
      gen = model.generate([1, 2, 3, 4, 5], start_pos=3)
      next(gen)  # drive generation far enough for one forward pass
    # With start_pos=3 the prompt fed to the model is tokens[3:] == [4, 5] (length 2).
    # The buggy behavior (start_pos forced back to 0) would feed all 5 tokens.
    first_shape, first_pos = seen[0]
    self.assertEqual(first_shape[-1], 2)
    self.assertEqual(first_pos, 3)
# Allow running this test module directly as a script.
if __name__ == "__main__":
  unittest.main()