mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
simple symbolic slice in llama [pr] (#10112)
support slice that has step None and stop > start
This commit is contained in:
@@ -77,10 +77,10 @@ class Attention:
|
||||
|
||||
# update the cache
|
||||
assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
|
||||
self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
|
||||
self.cache_kv[:, :, start_pos:start_pos+seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()
|
||||
|
||||
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
|
||||
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
|
||||
keys = self.cache_kv[0, :, 0:start_pos+seqlen, :, :]
|
||||
values = self.cache_kv[1, :, 0:start_pos+seqlen, :, :]
|
||||
|
||||
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
||||
@@ -176,7 +176,7 @@ class Transformer:
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
self.freqs_cis = self.freqs_cis.cast(h.dtype).kernelize()
|
||||
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
|
||||
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen, :, :, :]
|
||||
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).kernelize() if seqlen > 1 else None
|
||||
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
|
||||
|
||||
@@ -184,6 +184,19 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_slice(self):
|
||||
# slice is a movement, so we pair it with a simple function to test the JIT interaction
|
||||
def f(a): return (a+1).realize()
|
||||
jf = TinyJit(f)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(7, 11)
|
||||
symbolic = a[3:5, vi:vi+2]
|
||||
symbolic = jf(symbolic).numpy()
|
||||
expected = f(a[3:5, i:i+2]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_ones_sum(self):
|
||||
def f(a): return a.sum().realize()
|
||||
jf = TinyJit(f)
|
||||
|
||||
@@ -164,6 +164,15 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
expected = a.shrink(((3,5),(i,i+2))).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_slice(self):
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(7, 11)
|
||||
symbolic = a[3:5, vi:vi+2]
|
||||
symbolic = symbolic.numpy()
|
||||
expected = a[3:5, i:i+2].numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_ones_sum(self):
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
|
||||
@@ -1118,13 +1118,18 @@ class Tensor(SimpleMathTrait):
|
||||
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
|
||||
case slice():
|
||||
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
|
||||
if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported")
|
||||
# handle int slicing
|
||||
*boundary, stride = index.indices(cast(SupportsIndex, size))
|
||||
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
|
||||
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
|
||||
# update size for slice
|
||||
size = ceildiv((boundary[1] - boundary[0]), abs(stride))
|
||||
if all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)):
|
||||
# handle int slicing
|
||||
*boundary, stride = index.indices(cast(SupportsIndex, size))
|
||||
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
|
||||
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
|
||||
# update size for slice
|
||||
size = ceildiv((boundary[1] - boundary[0]), abs(stride))
|
||||
elif index.step is None and all(isinstance(s,(int,UOp))for s in (index.start,index.stop)) and resolve((index.stop-index.start) > 0, False):
|
||||
# simple symbolic slice
|
||||
boundary = [index.start, index.stop]
|
||||
size = (index.stop - index.start).ssimplify()
|
||||
else: raise TypeError(f"slice {index=} is not supported")
|
||||
case None: pass # do nothing
|
||||
case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
|
||||
indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})
|
||||
|
||||
Reference in New Issue
Block a user