restrict assignment to base (#3809)
* restrict assignment to base
* add some restrictions there
* more restrictions
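The user-visible effect, roughly: in-place updates must target a tensor's underlying base buffer, so updating through a non-contiguous view of an already-realized tensor is now rejected (this is what the TestAssign hunk below checks). A minimal illustrative sketch, not taken from the PR itself:

from tinygrad.tensor import Tensor

a = Tensor.rand(4, 4).realize()
b = Tensor.rand(4, 4).realize()

a += b                 # fine: a is the realized base buffer, so this lowers to an in-place assign
a.realize()

a = a.permute(1, 0)    # a is now a non-contiguous view of its (realized) base
try:
  a += b               # rejected: assignment is restricted to the base
  a.realize()
except (RuntimeError, AssertionError) as e:
  print("rejected:", e)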
@@ -69,8 +69,8 @@ class Attention:
       self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
       if isinstance(x.device, tuple):
         # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_k.shard_((xk.device), axis=None)
-        self.cache_v.shard_((xv.device), axis=None)
+        self.cache_k.shard_((xk.device), axis=None).realize()
+        self.cache_v.shard_((xv.device), axis=None).realize()
 
     # HACK: without contiguous, the conversation mode is broken and the cache is not updated
     keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
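The hunk above appends .realize() to the shard_ calls on the KV cache. Since an assign target must now already be realized when it is scheduled (see the _recursive_lazyop hunk below), realizing the per-device copies up front presumably keeps later cache updates on the in-place assign path instead of a lazy replace. A minimal sketch of the same pattern, with "GPU:0"/"GPU:1" as placeholder device names:

from tinygrad.tensor import Tensor

cache = Tensor.zeros(1, 16, 4, 8).contiguous().realize()
cache.shard_(("GPU:0", "GPU:1"), axis=None).realize()  # realize the per-device copies right away
# later updates such as cache.assign(...) then target already-realized shard buffers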
@@ -178,7 +178,7 @@ class TestAssign(unittest.TestCase):
     b.realize()
     ba1 = a.lazydata.base.realized
     bb1 = b.lazydata.base.realized
-    with self.assertRaises(RuntimeError):
+    with self.assertRaises((RuntimeError, AssertionError)):
       a = a.permute(1,0)
       a += b
       a.realize()
@@ -57,6 +57,7 @@ class MultiLazyBuffer:
   def is_unrealized_contiguous_const(self): return False
 
   # passthroughs
+  def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])
   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
   def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
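The new MultiLazyBuffer.is_realized mirrors the LazyBuffer helper added below: a sharded buffer counts as realized only once every shard marked real has a realized base. A small usage sketch, again with placeholder device names:

from tinygrad.tensor import Tensor

t = Tensor.arange(8).realize().shard(("GPU:0", "GPU:1"), axis=0)  # lazydata becomes a MultiLazyBuffer
print(t.lazydata.is_realized())  # False: the per-device shards are still lazy copies
t.realize()
print(t.lazydata.is_realized())  # True: every real shard's base buffer now exists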
@@ -59,10 +59,12 @@ class LazyBuffer:
     shape = self.shape if shape is None else shape
     return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=cast_scalar(val, self.dtype)).reshape((1,)*len(shape)).expand(shape)
 
+  def is_realized(self) -> bool: return self.base.realized is not None
+
   def assign(self, x:LazyBuffer) -> LazyBuffer:
-    if self.base.realized is not None or self is not self.base: new_self = self
-    else: new_self = create_lazybuffer(self.device, self.st, self.dtype, self.op, self.arg, self.srcs, enable_cache=False)
-    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, new_self))
+    assert (self.base is self) or (self.st.contiguous and self.size == self.base.size), f"assign target must be contiguous {self.st}"
+    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self.base))
+
   def contiguous(self):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
       ret = self.e(LoadOps.CONTIGUOUS)
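The tightened LazyBuffer.assign keeps one carve-out: a view may still be assigned through if it is contiguous and covers the whole base. A rough sketch of what the new assert accepts and rejects, assuming this revision's semantics:

from tinygrad.tensor import Tensor

a = Tensor.zeros(2, 8).contiguous().realize()

a.assign(Tensor.ones(2, 8)).realize()        # target is the realized base buffer: allowed
a.reshape(4, 4).assign(Tensor.ones(4, 4))    # contiguous view covering the whole base: still accepted
try:
  a.permute(1, 0).assign(Tensor.ones(8, 2))  # non-contiguous view: "assign target must be contiguous"
except (AssertionError, RuntimeError) as e:
  print("rejected:", e)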
@@ -107,10 +107,14 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Var
     return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
 
   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
-  if buf.op in {LoadOps.CONTIGUOUS, LoadOps.ASSIGN}:
+  if buf.op is LoadOps.CONTIGUOUS:
     assert first
-    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False,
-                             assign_to=buf.srcs[1].base if buf.op is LoadOps.ASSIGN else None)
+    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
+  if buf.op is LoadOps.ASSIGN:
+    assert first
+    assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
+    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
+    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1])
 
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
@@ -160,8 +160,7 @@ class Tensor:
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
     assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
     assert not x.requires_grad  # self requires_grad is okay?
-    if (self.lazydata.lbs[0].base.realized is None if isinstance(self.lazydata, MultiLazyBuffer) else self.lazydata.base.realized is None):
-      return self.replace(x)
+    if not self.lazydata.is_realized(): return self.replace(x)
     self.lazydata = self.lazydata.assign(x.lazydata)
     return self
   def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
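With is_realized in place, Tensor.assign reduces to two paths: an unrealized target is simply replaced by the source graph, while a realized target gets a true in-place assign. A small sketch of both paths, assuming the semantics shown above:

from tinygrad.tensor import Tensor

a = Tensor.zeros(4)                         # unrealized target: assign just replaces the lazy graph
a.assign(Tensor.ones(4))
a.realize()

b = Tensor.zeros(4).contiguous().realize()  # realized target: assign schedules an in-place update
b.assign(Tensor.ones(4)).realize()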