mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
restrict assignment to base (#3809)
* restrict assignment to base * add some restrictions there * more restrictions
This commit is contained in:
@@ -69,8 +69,8 @@ class Attention:
|
||||
self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
|
||||
if isinstance(x.device, tuple):
|
||||
# TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
|
||||
self.cache_k.shard_((xk.device), axis=None)
|
||||
self.cache_v.shard_((xv.device), axis=None)
|
||||
self.cache_k.shard_((xk.device), axis=None).realize()
|
||||
self.cache_v.shard_((xv.device), axis=None).realize()
|
||||
|
||||
# HACK: without contiguous, the conversation mode is broken and the cache is not updated
|
||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
|
||||
|
||||
@@ -178,7 +178,7 @@ class TestAssign(unittest.TestCase):
|
||||
b.realize()
|
||||
ba1 = a.lazydata.base.realized
|
||||
bb1 = b.lazydata.base.realized
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaises((RuntimeError, AssertionError)):
|
||||
a = a.permute(1,0)
|
||||
a += b
|
||||
a.realize()
|
||||
|
||||
@@ -57,6 +57,7 @@ class MultiLazyBuffer:
|
||||
def is_unrealized_contiguous_const(self): return False
|
||||
|
||||
# passthroughs
|
||||
def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])
|
||||
def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
|
||||
def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
|
||||
def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
|
||||
|
||||
@@ -59,10 +59,12 @@ class LazyBuffer:
|
||||
shape = self.shape if shape is None else shape
|
||||
return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=cast_scalar(val, self.dtype)).reshape((1,)*len(shape)).expand(shape)
|
||||
|
||||
def is_realized(self) -> bool: return self.base.realized is not None
|
||||
|
||||
def assign(self, x:LazyBuffer) -> LazyBuffer:
|
||||
if self.base.realized is not None or self is not self.base: new_self = self
|
||||
else: new_self = create_lazybuffer(self.device, self.st, self.dtype, self.op, self.arg, self.srcs, enable_cache=False)
|
||||
return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, new_self))
|
||||
assert (self.base is self) or (self.st.contiguous and self.size == self.base.size), f"assign target must be contiguous {self.st}"
|
||||
return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self.base))
|
||||
|
||||
def contiguous(self):
|
||||
if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
|
||||
ret = self.e(LoadOps.CONTIGUOUS)
|
||||
|
||||
@@ -107,10 +107,14 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Var
|
||||
return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
|
||||
|
||||
# if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
|
||||
if buf.op in {LoadOps.CONTIGUOUS, LoadOps.ASSIGN}:
|
||||
if buf.op is LoadOps.CONTIGUOUS:
|
||||
assert first
|
||||
return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False,
|
||||
assign_to=buf.srcs[1].base if buf.op is LoadOps.ASSIGN else None)
|
||||
return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
|
||||
if buf.op is LoadOps.ASSIGN:
|
||||
assert first
|
||||
assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
|
||||
assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
|
||||
return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1])
|
||||
|
||||
# if it's a reduce, we have to change the shapetracker
|
||||
if buf.op in ReduceOps:
|
||||
|
||||
@@ -160,8 +160,7 @@ class Tensor:
|
||||
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
||||
assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
|
||||
assert not x.requires_grad # self requires_grad is okay?
|
||||
if (self.lazydata.lbs[0].base.realized is None if isinstance(self.lazydata, MultiLazyBuffer) else self.lazydata.base.realized is None):
|
||||
return self.replace(x)
|
||||
if not self.lazydata.is_realized(): return self.replace(x)
|
||||
self.lazydata = self.lazydata.assign(x.lazydata)
|
||||
return self
|
||||
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
|
||||
|
||||
Reference in New Issue
Block a user