mirror of https://github.com/tinygrad/tinygrad.git
setitem initial support (#4093)
* wip setitem: it's an eager assign to the output shapetracker view
* cleanups and tests
* more cleanups
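In practice this enables numpy-style slice assignment on a realized tensor: indexing produces a view of the buffer and the value is assigned into it eagerly. A hedged sketch of the two equivalent spellings this diff introduces (shapes and values are illustrative, not from the repo):

from tinygrad import Tensor

t = Tensor.zeros(6, 6).contiguous().realize()

# new setitem path: index -> view -> eager assign into the underlying buffer
t[2:4, 3:5] = Tensor.ones(2, 2)

# the same update written directly against the view, as the KV-cache code below does
t.shrink(((2, 4), (3, 5))).assign(Tensor.ones(2, 2).contiguous()).realize()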
@@ -34,17 +34,16 @@ class Attention:
     if not hasattr(self, "cache_kv"):
       self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()

+    # update the cache
+    self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack([xk, xv])).realize()
+
     if start_pos > 0:
-      keys = self.cache_kv[0].shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
-      values = self.cache_kv[1].shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
+      keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
+      values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
     else:
       keys = xk
       values = xv

-    # update the cache
-    new_cache = Tensor.stack([keys, values]).pad((None, None,(0,MAX_CONTEXT-start_pos-seqlen),None,None))
-    self.cache_kv.assign(new_cache).realize()
-
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))
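The Attention hunk above drops the cat/pad rebuild of the whole cache in favor of one assign into a shrunk window of cache_kv, then reads the seen prefix back as views. A minimal sketch of that pattern on toy shapes (sizes and names here are illustrative only):

from tinygrad import Tensor

MAX_CONTEXT, HEADS, DIM = 8, 2, 4
cache_kv = Tensor.zeros(2, 1, MAX_CONTEXT, HEADS, DIM).contiguous().realize()

start_pos, seqlen = 3, 2
xk = Tensor.ones(1, seqlen, HEADS, DIM)
xv = Tensor.ones(1, seqlen, HEADS, DIM)

# write only the window [start_pos, start_pos+seqlen) of the cache
cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack([xk, xv])).realize()

# read everything seen so far back as views; no cat/pad over the full cache
keys   = cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
values = cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))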
@@ -73,13 +73,15 @@ class Attention:
         self.cache_v.shard_((xv.device), axis=None).realize()

-    # HACK: without contiguous, the conversation mode is broken and the cache is not updated
-    keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
-    values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1).contiguous() if start_pos > 0 else xv
-
     # update the cache
-    assert keys.dtype == self.cache_k.dtype and values.dtype == self.cache_v.dtype, f"{keys.dtype=}, {values.dtype=}, {self.cache_k.dtype=}, {self.cache_v.dtype=}"
-    self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None))).realize()
-    self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None))).realize()
+    assert xk.dtype == self.cache_k.dtype and xv.dtype == self.cache_v.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_k.dtype=}, {self.cache_v.dtype=}"
+    self.cache_k.shrink((None, (start_pos, start_pos+seqlen), None, None)).assign(xk.contiguous()).realize()
+    self.cache_v.shrink((None, (start_pos, start_pos+seqlen), None, None)).assign(xv.contiguous()).realize()
+
+    keys = self.cache_k.shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_v.shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv

     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
44  test/test_setitem.py  Normal file
@@ -0,0 +1,44 @@
+import unittest
+from tinygrad import Tensor, TinyJit, Variable, dtypes
+import numpy as np
+
+class TestSetitem(unittest.TestCase):
+  def test_simple_setitem(self):
+    t = Tensor.zeros(6, 6).contiguous().realize()
+    t[2:4, 3:5] = Tensor.ones(2, 2)
+    n = np.zeros((6, 6))
+    n[2:4, 3:5] = np.ones((2, 2))
+    np.testing.assert_allclose(t.numpy(), n)
+
+  def test_simple_jit_setitem(self):
+    @TinyJit
+    def f(t:Tensor, a:Tensor):
+      t[2:4, 3:5] = a
+
+    for i in range(1, 6):
+      t = Tensor.zeros(6, 6).contiguous().realize()
+      a = Tensor.full((2, 2), fill_value=i, dtype=dtypes.float).contiguous()
+      f(t, a)
+
+      n = np.zeros((6, 6))
+      n[2:4, 3:5] = np.full((2, 2), i)
+      np.testing.assert_allclose(t.numpy(), n)
+
+  def test_jit_setitem_variable_offset(self):
+    @TinyJit
+    def f(t:Tensor, a:Tensor, v:Variable):
+      t.shrink(((v,v+1), None)).assign(a).realize()
+
+    t = Tensor.zeros(6, 6).contiguous().realize()
+    n = np.zeros((6, 6))
+
+    for i in range(6):
+      v = Variable("v", 0, 6).bind(i)
+      a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous()
+      n[i, :] = i+1
+      f(t, a, v)
+      np.testing.assert_allclose(t.numpy(), n)
+    np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
+
+if __name__ == '__main__':
+  unittest.main()
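test_jit_setitem_variable_offset is the pattern the KV-cache hunks above rely on: the write offset is a bound Variable, so a single jitted kernel can fill a different slice of the same buffer on every call. A hedged sketch of the same idea driving a toy decode-style loop (MAX_CONTEXT, update, and the shapes are made up for illustration):

from tinygrad import Tensor, TinyJit, Variable, dtypes

MAX_CONTEXT, DIM = 8, 4
cache = Tensor.zeros(MAX_CONTEXT, DIM).contiguous().realize()

@TinyJit
def update(cache:Tensor, xk:Tensor, pos:Variable):
  # one compiled kernel, reused for every step; only the bound value of pos changes
  cache.shrink(((pos, pos+1), None)).assign(xk).realize()

for step in range(MAX_CONTEXT):
  xk = Tensor.full((1, DIM), fill_value=step+1, dtype=dtypes.float).contiguous()
  update(cache, xk, Variable("pos", 0, MAX_CONTEXT).bind(step))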
@@ -79,8 +79,11 @@ def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[
     op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
   else:
     output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
+    output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
     op = _recursive_lazyop(out, membufs, var_vals, output_st, realizes, cache={})
-    op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])), membufs[1:]
+    output_view, vv = output_view.simplify().unbind()
+    if vv: var_vals.update(vv)
+    op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_view)), membufs[1:]
   return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)

 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
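With this change, when the output comes from an ASSIGN that carries a view (out.arg), the STORE's MemBuffer is built from that view instead of a fresh contiguous ShapeTracker, so the generated kernel writes only the selected window of the target buffer. A hedged sketch of what such a view looks like (assuming the tinygrad.shape.shapetracker import path; purely illustrative):

from tinygrad.shape.shapetracker import ShapeTracker

full = ShapeTracker.from_shape((6, 6))    # contiguous view over the whole output buffer
window = full.shrink(((2, 4), (3, 5)))    # the kind of view an assign into t[2:4, 3:5] would carry
print(window.shape)                       # (2, 2): the STORE indexes only this sub-region of the buffer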
@@ -64,8 +64,8 @@ class LazyBuffer:
   def is_realized(self) -> bool: return self.base.realized is not None

   def assign(self, x:LazyBuffer) -> LazyBuffer:
-    assert (self.base is self) or (self.st.contiguous and self.size == self.base.size), f"assign target must be contiguous {self.st}"
-    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self.base))
+    assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
+    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))

   def contiguous(self):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
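assign no longer insists on a contiguous target: when the target is a view, its ShapeTracker is forwarded as the ASSIGN arg, which is what the scheduler hunk above reads back as output_view. A hedged sketch of how a Tensor-level call reaches this path, assuming this era's Tensor.lazydata/LazyBuffer internals (illustrative only):

from tinygrad import Tensor

t = Tensor.zeros(6, 6).contiguous().realize()
target = t.shrink(((2, 4), (3, 5)))                     # a non-contiguous view of t's buffer
print(target.lazydata.st.contiguous)                    # False -> assign() forwards (self.st,) as the ASSIGN arg
target.assign(Tensor.ones(2, 2).contiguous()).realize() # writes only the 2x2 window of the realized buffer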
@@ -510,9 +510,13 @@ class Tensor:
         ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
     return ret

-  def __setitem__(self,indices,v):
+  def __setitem__(self, indices, v:Tensor):
     if isinstance(self.device, str) and self.device.startswith("DISK"): return self.__getitem__(indices).assign(v)
-    raise NotImplementedError("not implemented yet")
+    # TODO: support python const v
+    # TODO: broadcast v to the shape here, refactor for const v and one way broadcast_shape
+    assign_to = self.__getitem__(indices)
+    # NOTE: contiguous to prevent const folding.
+    return assign_to.assign(v._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()).realize()

   # NOTE: using slice is discouraged and things should migrate to pad and shrink
   def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
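Per the TODOs above, v must already be a Tensor (python scalars are not handled yet), but it is broadcast to the indexed shape with broadcast_shape before the eager assign. A hedged usage sketch with illustrative values:

from tinygrad import Tensor

t = Tensor.zeros(4, 4).contiguous().realize()
t[1:3, 1:3] = Tensor.full((2, 2), fill_value=7.0)   # value already matching the indexed shape
t[0:1, 0:4] = Tensor.ones(4)                        # broadcast_shape((1, 4), (4,)) -> (1, 4)
print(t.numpy())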