test_assign_kv_cache

This commit is contained in:
George Hotz
2024-03-14 16:17:20 -07:00
parent 38ba277ac8
commit d52d0b0efb

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad import dtypes, TinyJit, GlobalCounters
from tinygrad import dtypes, TinyJit, GlobalCounters, Variable
N = 200 # has to be bigger than the cache to fail
@@ -69,6 +69,30 @@ class TestAssign(unittest.TestCase):
for _ in range(4): f(y)
assert y.item() == 4
def test_assign_kv_cache(self):
  """Grow a key cache through a jitted in-place assign and check its final contents.

  A small `Attn` helper keeps a `cache_k` tensor of shape
  (bsz, max_context, 1, 1). Each call takes the first `start_pos` cached
  entries, appends the incoming keys along dim 1, pads the result back out to
  `max_context` and assigns it into the cache.
  """
  bsz, max_context = 2, 8

  class Attn:
    @TinyJit
    def __call__(self, xk:Tensor, start_pos:Variable):
      # the zero-filled cache is created lazily, on the first call only
      if not hasattr(self, "cache_k"):
        self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).contiguous()
      seqlen = xk.shape[1]
      if start_pos > 0:
        # keep the already-written prefix and append the new keys after it
        prefix = self.cache_k.shrink((None, (0, start_pos), None, None))
        keys = prefix.cat(xk, dim=1).contiguous()
      else:
        keys = xk
      # pad along dim 1 so the assigned tensor has the full cache shape again
      remaining = max_context-start_pos-seqlen
      padded = keys.pad((None, (0, remaining), None, None)).contiguous()
      self.cache_k.assign(padded).realize()

  attn = Attn()
  # initial call: three keys written at position 0, passed as a plain int
  new_keys = Tensor.ones(bsz, 3, 1, 1).contiguous()
  attn(new_keys, 0)
  for pos in range(3, 6):
    # copied from LLaMA: the start position is a bound symbolic Variable
    bound_pos = Variable("start_pos", 1, max_context).bind(pos)
    new_keys = Tensor.ones(bsz, 1, 1, 1).contiguous()
    attn(new_keys, bound_pos)

  # per batch row: six ones (3 initial + 3 appended) followed by two zeros
  expected = ([1.] * 6 + [0.] * 2) * bsz
  out = attn.cache_k.flatten().numpy()
  np.testing.assert_allclose(out, expected)
def test_assign_contiguous(self):
b = Tensor.rand(4,4).realize()
a = (Tensor.rand(4,4).realize() + 1)