From 17d4d258ea3457db874388a3ddfd88126baa97ee Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 30 Apr 2025 14:36:35 -0400 Subject: [PATCH] simple symbolic slice in llama [pr] (#10112) support slice that has step None and stop > start --- extra/models/llama.py | 8 ++++---- test/test_symbolic_jit.py | 13 +++++++++++++ test/test_symbolic_ops.py | 9 +++++++++ tinygrad/tensor.py | 19 ++++++++++++------- 4 files changed, 38 insertions(+), 11 deletions(-) diff --git a/extra/models/llama.py b/extra/models/llama.py index 3a089facbc..d164fa76a2 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -77,10 +77,10 @@ class Attention: # update the cache assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}" - self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize() + self.cache_kv[:, :, start_pos:start_pos+seqlen, :, :].assign(Tensor.stack(xk, xv)).realize() - keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) - values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) + keys = self.cache_kv[0, :, 0:start_pos+seqlen, :, :] + values = self.cache_kv[1, :, 0:start_pos+seqlen, :, :] keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep) xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2) @@ -176,7 +176,7 @@ class Transformer: h = self.tok_embeddings(tokens) self.freqs_cis = self.freqs_cis.cast(h.dtype).kernelize() - freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None)) + freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen, :, :, :] mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).kernelize() if seqlen > 1 else None for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index 7c2d171bf1..1529e3bc21 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -184,6 +184,19 @@ class TestSymbolicJit(unittest.TestCase): np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) + def test_slice(self): + # slice is a movement, so we pair it with a simple function to test the JIT interaction + def f(a): return (a+1).realize() + jf = TinyJit(f) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(7, 11) + symbolic = a[3:5, vi:vi+2] + symbolic = jf(symbolic).numpy() + expected = f(a[3:5, i:i+2]).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) + def test_ones_sum(self): def f(a): return a.sum().realize() jf = TinyJit(f) diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index 56bd2f6b15..55bf2c9820 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -164,6 +164,15 @@ class TestSymbolicOps(unittest.TestCase): expected = a.shrink(((3,5),(i,i+2))).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_slice(self): + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(7, 11) + symbolic = a[3:5, vi:vi+2] + symbolic = symbolic.numpy() + expected = a[3:5, i:i+2].numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_ones_sum(self): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1fc4b37ded..d625b33c6e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1118,13 +1118,18 @@ class Tensor(SimpleMathTrait): boundary = [index, index+1] if index >= 0 else [index+size, index+size+1] case slice(): if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step") - if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported") - # handle int slicing - *boundary, stride = index.indices(cast(SupportsIndex, size)) - if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0] - elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1] - # update size for slice - size = ceildiv((boundary[1] - boundary[0]), abs(stride)) + if all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)): + # handle int slicing + *boundary, stride = index.indices(cast(SupportsIndex, size)) + if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0] + elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1] + # update size for slice + size = ceildiv((boundary[1] - boundary[0]), abs(stride)) + elif index.step is None and all(isinstance(s,(int,UOp))for s in (index.start,index.stop)) and resolve((index.stop-index.start) > 0, False): + # simple symbolic slice + boundary = [index.start, index.stop] + size = (index.stop - index.start).ssimplify() + else: raise TypeError(f"slice {index=} is not supported") case None: pass # do nothing case _: raise IndexError(f"{type(index).__name__} indexing is not supported") indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})