setitem initial support (#4093)

* wip setitem

it's an eager assign to the output ShapeTracker view

* cleanups and tests

* more cleanups
chenyu authored on 2024-04-07 20:35:22 -04:00, committed by GitHub
parent 183708b3fd
commit 92c0675ccf
6 changed files with 68 additions and 16 deletions
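The core idea: t[indices] = v is now lowered to an eager assign through the shrunk view of t's own buffer. A minimal sketch of the equivalence using only the public Tensor API (shapes and values here are illustrative, not taken from the diff):

from tinygrad import Tensor
import numpy as np

t = Tensor.zeros(6, 6).contiguous().realize()
t[2:4, 3:5] = Tensor.ones(2, 2)            # new sugar: eager in-place write

u = Tensor.zeros(6, 6).contiguous().realize()
# roughly what it lowers to: assign into the shrunk view of the same buffer
u.shrink(((2, 4), (3, 5))).assign(Tensor.ones(2, 2).contiguous()).realize()

np.testing.assert_allclose(t.numpy(), u.numpy())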

View File

@@ -34,17 +34,16 @@ class Attention:
     if not hasattr(self, "cache_kv"):
       self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
+    # update the cache
+    self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack([xk, xv])).realize()
     if start_pos > 0:
-      keys = self.cache_kv[0].shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
-      values = self.cache_kv[1].shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
+      keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
+      values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
     else:
       keys = xk
       values = xv
-    # update the cache
-    new_cache = Tensor.stack([keys, values]).pad((None, None,(0,MAX_CONTEXT-start_pos-seqlen),None,None))
-    self.cache_kv.assign(new_cache).realize()
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))
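For illustration, the same pattern outside the model: a preallocated cache is updated in place one step at a time with shrink + assign, instead of being rebuilt with cat + pad each call. Sizes and names below are made up for the sketch; this is not the model code.

from tinygrad import Tensor
import numpy as np

MAX_CONTEXT, HEADS, HEAD_DIM = 8, 2, 4
cache = Tensor.zeros(MAX_CONTEXT, HEADS, HEAD_DIM).contiguous().realize()

for start_pos in range(3):
  new_kv = (Tensor.ones(1, HEADS, HEAD_DIM) * (start_pos + 1)).contiguous()
  # write only the new timestep into its slot of the cache, eagerly
  cache.shrink(((start_pos, start_pos + 1), None, None)).assign(new_kv).realize()
  # read the valid prefix back as a view of the cache, no cat needed
  prefix = cache.shrink(((0, start_pos + 1), None, None))
  assert prefix.shape == (start_pos + 1, HEADS, HEAD_DIM)

np.testing.assert_allclose(cache.numpy()[:3, 0, 0], [1.0, 2.0, 3.0])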

View File

@@ -73,13 +73,15 @@ class Attention:
         self.cache_v.shard_((xv.device), axis=None).realize()
     # HACK: without contiguous, the conversation mode is broken and the cache is not updated
-    keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
-    values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1).contiguous() if start_pos > 0 else xv
     # update the cache
-    assert keys.dtype == self.cache_k.dtype and values.dtype == self.cache_v.dtype, f"{keys.dtype=}, {values.dtype=}, {self.cache_k.dtype=}, {self.cache_v.dtype=}"
-    self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None))).realize()
-    self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None))).realize()
+    assert xk.dtype == self.cache_k.dtype and xv.dtype == self.cache_v.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_k.dtype=}, {self.cache_v.dtype=}"
+    self.cache_k.shrink((None, (start_pos, start_pos+seqlen), None, None)).assign(xk.contiguous()).realize()
+    self.cache_v.shrink((None, (start_pos, start_pos+seqlen), None, None)).assign(xv.contiguous()).realize()
+    keys = self.cache_k.shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_v.shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
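A small check of why the rewrite is safe: the keys the old cat-based path would have produced match the prefix read back from the updated cache. This is a sketch with made-up sizes that mirrors the pattern above, not the llama code itself.

from tinygrad import Tensor
import numpy as np

max_context, head_dim, start_pos, seqlen = 8, 4, 3, 1
cache_k = Tensor.zeros(max_context, head_dim).contiguous().realize()
cache_k.shrink(((0, start_pos), None)).assign(Tensor.ones(start_pos, head_dim).contiguous()).realize()  # pretend 3 steps already ran
xk = Tensor.full((seqlen, head_dim), 2.0).contiguous()

# old path: concatenate the cached prefix with the new keys
keys_old = cache_k.shrink(((0, start_pos), None)).cat(xk, dim=0).numpy()

# new path: write xk into its slot, then read the whole valid prefix from the cache
cache_k.shrink(((start_pos, start_pos + seqlen), None)).assign(xk).realize()
keys_new = cache_k.shrink(((0, start_pos + seqlen), None)).numpy()

np.testing.assert_allclose(keys_old, keys_new)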

test/test_setitem.py (new file, 44 lines added)
View File

@@ -0,0 +1,44 @@
import unittest
from tinygrad import Tensor, TinyJit, Variable, dtypes
import numpy as np

class TestSetitem(unittest.TestCase):
  def test_simple_setitem(self):
    t = Tensor.zeros(6, 6).contiguous().realize()
    t[2:4, 3:5] = Tensor.ones(2, 2)

    n = np.zeros((6, 6))
    n[2:4, 3:5] = np.ones((2, 2))
    np.testing.assert_allclose(t.numpy(), n)

  def test_simple_jit_setitem(self):
    @TinyJit
    def f(t:Tensor, a:Tensor):
      t[2:4, 3:5] = a

    for i in range(1, 6):
      t = Tensor.zeros(6, 6).contiguous().realize()
      a = Tensor.full((2, 2), fill_value=i, dtype=dtypes.float).contiguous()
      f(t, a)

      n = np.zeros((6, 6))
      n[2:4, 3:5] = np.full((2, 2), i)
      np.testing.assert_allclose(t.numpy(), n)

  def test_jit_setitem_variable_offset(self):
    @TinyJit
    def f(t:Tensor, a:Tensor, v:Variable):
      t.shrink(((v,v+1), None)).assign(a).realize()

    t = Tensor.zeros(6, 6).contiguous().realize()
    n = np.zeros((6, 6))
    for i in range(6):
      v = Variable("v", 0, 6).bind(i)
      a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous()
      n[i, :] = i+1
      f(t, a, v)
      np.testing.assert_allclose(t.numpy(), n)

    np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])

if __name__ == '__main__':
  unittest.main()
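Since the file ends with unittest.main(), the new tests can be run directly with python test/test_setitem.py, or picked up by the usual pytest run.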

View File

@@ -79,8 +79,11 @@ def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
else:
output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
op = _recursive_lazyop(out, membufs, var_vals, output_st, realizes, cache={})
op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])), membufs[1:]
output_view, vv = output_view.simplify().unbind()
if vv: var_vals.update(vv)
op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_view)), membufs[1:]
return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
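The unbind step in isolation, for illustration: a view that still contains a bound Variable is split into the unbound ShapeTracker plus the variable's value, which the scheduler now merges into var_vals for the STORE. This sketch uses the internal APIs named in the diff and assumes they are importable as shown; values are made up.

from tinygrad import Variable
from tinygrad.shape.shapetracker import ShapeTracker

v = Variable("start_pos", 0, 7).bind(3)
# a 1-row window into an (8, 4) buffer, offset by the bound variable
st = ShapeTracker.from_shape((8, 4)).shrink(((v, v + 1), (0, 4)))
unbound_st, var_vals = st.simplify().unbind()
print(unbound_st)   # the view with start_pos left symbolic
print(var_vals)     # maps start_pos -> 3, folded into the schedule's var_vals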

View File

@@ -64,8 +64,8 @@ class LazyBuffer:
   def is_realized(self) -> bool: return self.base.realized is not None
 
   def assign(self, x:LazyBuffer) -> LazyBuffer:
-    assert (self.base is self) or (self.st.contiguous and self.size == self.base.size), f"assign target must be contiguous {self.st}"
-    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self.base))
+    assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
+    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
 
   def contiguous(self):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():

View File

@@ -510,9 +510,13 @@ class Tensor:
     ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
     return ret
 
-  def __setitem__(self,indices,v):
+  def __setitem__(self, indices, v:Tensor):
     if isinstance(self.device, str) and self.device.startswith("DISK"): return self.__getitem__(indices).assign(v)
-    raise NotImplementedError("not implemented yet")
+    # TODO: support python const v
+    # TODO: broadcast v to the shape here, refactor for const v and one way broadcast_shape
+    assign_to = self.__getitem__(indices)
+    # NOTE: contiguous to prevent const folding.
+    return assign_to.assign(v._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()).realize()
 
   # NOTE: using slice is discouraged and things should migrate to pad and shrink
   def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
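Usage notes on the new __setitem__, per the TODOs above: the value must be a Tensor for now (a plain python scalar is not yet supported at this commit), and it is broadcast to the indexed region before the assign. A small sketch with illustrative shapes:

from tinygrad import Tensor
import numpy as np

t = Tensor.zeros(4, 4).contiguous().realize()
t[1:3, :] = Tensor.ones(4)     # (4,) broadcasts to the (2, 4) target region

n = np.zeros((4, 4))
n[1:3, :] = 1.0
np.testing.assert_allclose(t.numpy(), n)

# not yet supported at this commit (see the TODO above): t[1:3, :] = 1.0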