simple symbolic slice in llama [pr] (#10112)

Support simple symbolic slices: slices whose start/stop are ints or symbolic values, with step None and stop > start.
This commit is contained in:
chenyu
2025-04-30 14:36:35 -04:00
committed by GitHub
parent b583ece8f3
commit 17d4d258ea
4 changed files with 38 additions and 11 deletions

View File

@@ -77,10 +77,10 @@ class Attention:
# update the cache
assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
self.cache_kv[:, :, start_pos:start_pos+seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
keys = self.cache_kv[0, :, 0:start_pos+seqlen, :, :]
values = self.cache_kv[1, :, 0:start_pos+seqlen, :, :]
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -176,7 +176,7 @@ class Transformer:
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.cast(h.dtype).kernelize()
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen, :, :, :]
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).kernelize() if seqlen > 1 else None
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)

View File

@@ -184,6 +184,19 @@ class TestSymbolicJit(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
def test_slice(self):
  """Run a slice with symbolic column bounds through a jitted function.

  A slice is only a movement op, so the sliced tensor is fed to a trivial
  jitted function (add one and realize) and the result is compared with the
  same function applied, un-jitted, to the equivalent plain-int slice.
  """
  def f(a): return (a+1).realize()
  jf = TinyJit(f)
  for start in range(1, 5):
    bound = Variable("i", 1, 10).bind(start)
    src = Tensor.rand(7, 11)
    # rows 3:5 are fixed; the two-column window starts at the bound variable
    jitted = jf(src[3:5, bound:bound+2]).numpy()
    reference = f(src[3:5, start:start+2]).numpy()
    np.testing.assert_allclose(jitted, reference, atol=1e-6, rtol=1e-6)
  # checked once, after all iterations have gone through the same jitted function
  assert_jit_cache_len(jf, 1)
def test_ones_sum(self):
def f(a): return a.sum().realize()
jf = TinyJit(f)

View File

@@ -164,6 +164,15 @@ class TestSymbolicOps(unittest.TestCase):
expected = a.shrink(((3,5),(i,i+2))).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_slice(self):
  """Compare a slice with symbolic column bounds against the equivalent int slice."""
  for start in range(1, 5):
    bound = Variable("i", 1, 10).bind(start)
    src = Tensor.rand(7, 11)
    # same source tensor for both sides: symbolic window vs. concrete window
    got = src[3:5, bound:bound+2].numpy()
    want = src[3:5, start:start+2].numpy()
    np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)
def test_ones_sum(self):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)

View File

@@ -1118,13 +1118,18 @@ class Tensor(SimpleMathTrait):
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
case slice():
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported")
# handle int slicing
*boundary, stride = index.indices(cast(SupportsIndex, size))
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
# update size for slice
size = ceildiv((boundary[1] - boundary[0]), abs(stride))
if all(isinstance(s, int) or s is None for s in (index.start,index.stop,index.step)):
# handle int slicing
*boundary, stride = index.indices(cast(SupportsIndex, size))
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
# update size for slice
size = ceildiv((boundary[1] - boundary[0]), abs(stride))
elif index.step is None and all(isinstance(s,(int,UOp))for s in (index.start,index.stop)) and resolve((index.stop-index.start) > 0, False):
# simple symbolic slice
boundary = [index.start, index.stop]
size = (index.stop - index.start).ssimplify()
else: raise TypeError(f"slice {index=} is not supported")
case None: pass # do nothing
case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})