From 92c0675ccf4dcb7c673cca096fc3e1fd804d69ec Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sun, 7 Apr 2024 20:35:22 -0400
Subject: [PATCH] setitem initial support (#4093)

* wip setitem

it's an eager assign to output shapetracker view

* cleanups and tests

* more cleanups
---
 examples/gpt2.py            | 11 +++++-----
 extra/models/llama.py       | 12 +++++-----
 test/test_setitem.py        | 44 +++++++++++++++++++++++++++++++++++++
 tinygrad/engine/schedule.py |  5 ++++-
 tinygrad/lazy.py            |  4 ++--
 tinygrad/tensor.py          |  8 +++++--
 6 files changed, 68 insertions(+), 16 deletions(-)
 create mode 100644 test/test_setitem.py

diff --git a/examples/gpt2.py b/examples/gpt2.py
index f74b07627a..e9a18c67c3 100644
--- a/examples/gpt2.py
+++ b/examples/gpt2.py
@@ -34,17 +34,16 @@ class Attention:
     if not hasattr(self, "cache_kv"):
       self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
 
+    # update the cache
+    self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack([xk, xv])).realize()
+
     if start_pos > 0:
-      keys = self.cache_kv[0].shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
-      values = self.cache_kv[1].shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
+      keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
+      values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
     else:
       keys = xk
       values = xv
 
-    # update the cache
-    new_cache = Tensor.stack([keys, values]).pad((None, None,(0,MAX_CONTEXT-start_pos-seqlen),None,None))
-    self.cache_kv.assign(new_cache).realize()
-
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))
 
diff --git a/extra/models/llama.py b/extra/models/llama.py
index 2028a3f738..cb18db7a97 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -73,13 +73,15 @@ class Attention:
         self.cache_v.shard_((xv.device), axis=None).realize()
 
     # HACK: without contiguous, the conversation mode is broken and the cache is not updated
-    keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
-    values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1).contiguous() if start_pos > 0 else xv
 
     # update the cache
-    assert keys.dtype == self.cache_k.dtype and values.dtype == self.cache_v.dtype, f"{keys.dtype=}, {values.dtype=}, {self.cache_k.dtype=}, {self.cache_v.dtype=}"
-    self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None))).realize()
-    self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None))).realize()
+    assert xk.dtype == self.cache_k.dtype and xv.dtype == self.cache_v.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_k.dtype=}, {self.cache_v.dtype=}"
+    self.cache_k.shrink((None, (start_pos, start_pos+seqlen), None, None)).assign(xk.contiguous()).realize()
+    self.cache_v.shrink((None, (start_pos, start_pos+seqlen), None, None)).assign(xv.contiguous()).realize()
+
+    keys = self.cache_k.shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_v.shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
+
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
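Note on the cache update in both models above: instead of rebuilding the full cache with cat + pad and assigning the whole buffer, the new keys/values are assigned directly into a shrunk view of the preallocated cache. A minimal sketch of that pattern, with made-up shapes (batch=1, max_context=8, 1 head, head_dim=4) purely for illustration:

  from tinygrad import Tensor

  cache = Tensor.zeros(1, 8, 1, 4).contiguous().realize()
  start_pos, seqlen = 3, 2
  xk = Tensor.ones(1, seqlen, 1, 4)
  # write the new entries into rows [start_pos, start_pos+seqlen) of the cache
  cache.shrink((None, (start_pos, start_pos+seqlen), None, None)).assign(xk.contiguous()).realize()
  # read back everything cached so far
  keys = cache.shrink((None, (0, start_pos+seqlen), None, None))
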
diff --git a/test/test_setitem.py b/test/test_setitem.py
new file mode 100644
index 0000000000..481f8cad3f
--- /dev/null
+++ b/test/test_setitem.py
@@ -0,0 +1,44 @@
+import unittest
+from tinygrad import Tensor, TinyJit, Variable, dtypes
+import numpy as np
+
+class TestSetitem(unittest.TestCase):
+  def test_simple_setitem(self):
+    t = Tensor.zeros(6, 6).contiguous().realize()
+    t[2:4, 3:5] = Tensor.ones(2, 2)
+    n = np.zeros((6, 6))
+    n[2:4, 3:5] = np.ones((2, 2))
+    np.testing.assert_allclose(t.numpy(), n)
+
+  def test_simple_jit_setitem(self):
+    @TinyJit
+    def f(t:Tensor, a:Tensor):
+      t[2:4, 3:5] = a
+
+    for i in range(1, 6):
+      t = Tensor.zeros(6, 6).contiguous().realize()
+      a = Tensor.full((2, 2), fill_value=i, dtype=dtypes.float).contiguous()
+      f(t, a)
+
+      n = np.zeros((6, 6))
+      n[2:4, 3:5] = np.full((2, 2), i)
+      np.testing.assert_allclose(t.numpy(), n)
+
+  def test_jit_setitem_variable_offset(self):
+    @TinyJit
+    def f(t:Tensor, a:Tensor, v:Variable):
+      t.shrink(((v,v+1), None)).assign(a).realize()
+
+    t = Tensor.zeros(6, 6).contiguous().realize()
+    n = np.zeros((6, 6))
+
+    for i in range(6):
+      v = Variable("v", 0, 6).bind(i)
+      a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous()
+      n[i, :] = i+1
+      f(t, a, v)
+      np.testing.assert_allclose(t.numpy(), n)
+    np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
+
+if __name__ == '__main__':
+  unittest.main()
\ No newline at end of file
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 68e62b5108..cb4b8eecba 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -79,8 +79,11 @@ def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[
     op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
   else:
     output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
+    output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
    op = _recursive_lazyop(out, membufs, var_vals, output_st, realizes, cache={})
-    op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])), membufs[1:]
+    output_view, vv = output_view.simplify().unbind()
+    if vv: var_vals.update(vv)
+    op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_view)), membufs[1:]
   return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index d5b994f844..9a67ff97b5 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -64,8 +64,8 @@ class LazyBuffer:
   def is_realized(self) -> bool: return self.base.realized is not None
 
   def assign(self, x:LazyBuffer) -> LazyBuffer:
-    assert (self.base is self) or (self.st.contiguous and self.size == self.base.size), f"assign target must be contiguous {self.st}"
-    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self.base))
+    assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
+    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
 
   def contiguous(self):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
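How the two core changes above fit together: when the assign target is a non-contiguous view, LazyBuffer.assign now stashes the target ShapeTracker in the op's arg, and the scheduler picks that view up as the output view of the BufferOps.STORE (with any bound Variables folded into var_vals), so only the selected region of the underlying buffer is written. At the Tensor level this is the shrink + assign primitive used by the models and tests; a minimal sketch with values chosen only for illustration:

  from tinygrad import Tensor

  t = Tensor.zeros(4, 4).contiguous().realize()
  # store through a non-contiguous view: only the 2x2 block starting at (1, 1) is written
  t.shrink(((1, 3), (1, 3))).assign(Tensor.full((2, 2), fill_value=5.0).contiguous()).realize()
  print(t.numpy())  # zeros everywhere except t[1:3, 1:3] == 5
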
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 26eca8ed42..259a73cb8d 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -510,9 +510,13 @@ class Tensor:
       ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
     return ret
 
-  def __setitem__(self,indices,v):
+  def __setitem__(self, indices, v:Tensor):
     if isinstance(self.device, str) and self.device.startswith("DISK"): return self.__getitem__(indices).assign(v)
-    raise NotImplementedError("not implemented yet")
+    # TODO: support python const v
+    # TODO: broadcast v to the shape here, refactor for const v and one way broadcast_shape
+    assign_to = self.__getitem__(indices)
+    # NOTE: contiguous to prevent const folding.
+    return assign_to.assign(v._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()).realize()
 
   # NOTE: using slice is discouraged and things should migrate to pad and shrink
   def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
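Taken together, basic setitem now works eagerly on non-DISK devices; a small usage sketch mirroring test_simple_setitem (the right-hand side must currently be a Tensor, python scalars are left as a TODO in __setitem__):

  import numpy as np
  from tinygrad import Tensor

  t = Tensor.zeros(6, 6).contiguous().realize()
  t[2:4, 3:5] = Tensor.ones(2, 2)  # broadcast to the indexed shape and realized immediately
  np.testing.assert_allclose(t.numpy()[2:4, 3:5], np.ones((2, 2)))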