mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test_assign_kv_cache
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad import dtypes, TinyJit, GlobalCounters
|
||||
from tinygrad import dtypes, TinyJit, GlobalCounters, Variable
|
||||
|
||||
N = 200 # has to be bigger than the cache to fail
|
||||
|
||||
@@ -69,6 +69,30 @@ class TestAssign(unittest.TestCase):
|
||||
for _ in range(4): f(y)
|
||||
assert y.item() == 4
|
||||
|
||||
def test_assign_kv_cache(self):
  """In-place assignment into a KV cache under TinyJit (LLaMA-style pattern).

  A JITted ``__call__`` keeps a persistent ``cache_k`` tensor: each call
  shrinks the cache up to ``start_pos``, concatenates the new keys, pads
  back out to ``max_context``, and assigns the result into the cache in
  place. The test checks the cache holds the expected values after the
  JIT has been warmed up and then driven with a bound symbolic Variable.
  """
  bsz, max_context = 2, 8

  class Attn:
    @TinyJit
    def __call__(self, xk:Tensor, start_pos:Variable):
      seqlen = xk.shape[1]
      # lazily allocate the cache on first call; .contiguous() forces a
      # real buffer so later .assign() writes into stable storage
      if not hasattr(self, "cache_k"):
        self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).contiguous()
      # prepend the cached keys (first start_pos positions) to the new keys
      keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
      # pad back to max_context and write the updated keys into the cache
      self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize()

  attn = Attn()
  # first call with a concrete start_pos=0 fills positions 0..2
  xk = Tensor.ones(bsz, 3, 1, 1).contiguous()
  attn(xk, 0)
  # subsequent calls use a bound symbolic Variable for start_pos,
  # appending one position per step (positions 3..5)
  for i in range(3,6):
    # copied from LLaMA
    start_pos = Variable("start_pos", 1, max_context).bind(i)
    xk = Tensor.ones(bsz, 1, 1, 1).contiguous()
    attn(xk, start_pos)

  # positions 0..5 were written with ones, 6..7 remain zero, per batch row
  out = attn.cache_k.flatten().numpy()
  np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])
|
||||
|
||||
def test_assign_contiguous(self):
|
||||
b = Tensor.rand(4,4).realize()
|
||||
a = (Tensor.rand(4,4).realize() + 1)
|
||||
|
||||
Reference in New Issue
Block a user