From 4c4d3cb3e349e515b5065bce7eb983bd2d984f20 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 18 Mar 2024 15:33:06 -0700 Subject: [PATCH] restrict assignment to base (#3809) * restrict assignment to base * add some restrictions there * more restrictions --- extra/models/llama.py | 4 ++-- test/test_assign.py | 2 +- tinygrad/features/multi.py | 1 + tinygrad/lazy.py | 8 +++++--- tinygrad/realize.py | 10 +++++++--- tinygrad/tensor.py | 3 +-- 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/extra/models/llama.py b/extra/models/llama.py index 49c2e451de..2028a3f738 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -69,8 +69,8 @@ class Attention: self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize() if isinstance(x.device, tuple): # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded - self.cache_k.shard_((xk.device), axis=None) - self.cache_v.shard_((xv.device), axis=None) + self.cache_k.shard_((xk.device), axis=None).realize() + self.cache_v.shard_((xv.device), axis=None).realize() # HACK: without contiguous, the conversation mode is broken and the cache is not updated keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk diff --git a/test/test_assign.py b/test/test_assign.py index 13a9b77015..ef9b53826e 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -178,7 +178,7 @@ class TestAssign(unittest.TestCase): b.realize() ba1 = a.lazydata.base.realized bb1 = b.lazydata.base.realized - with self.assertRaises(RuntimeError): + with self.assertRaises((RuntimeError, AssertionError)): a = a.permute(1,0) a += b a.realize() diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py index df0b962b3c..a4ea00d886 100644 --- a/tinygrad/features/multi.py +++ b/tinygrad/features/multi.py @@ -57,6 +57,7 @@ class MultiLazyBuffer: def is_unrealized_contiguous_const(self): return False # passthroughs + def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True]) def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real) def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real) def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index eb729fa7e9..9410e112d2 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -59,10 +59,12 @@ class LazyBuffer: shape = self.shape if shape is None else shape return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=cast_scalar(val, self.dtype)).reshape((1,)*len(shape)).expand(shape) + def is_realized(self) -> bool: return self.base.realized is not None + def assign(self, x:LazyBuffer) -> LazyBuffer: - if self.base.realized is not None or self is not self.base: new_self = self - else: new_self = create_lazybuffer(self.device, self.st, self.dtype, self.op, self.arg, self.srcs, enable_cache=False) - return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, new_self)) + assert (self.base is self) or (self.st.contiguous and self.size == self.base.size), f"assign target must be contiguous {self.st}" + return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self.base)) + def contiguous(self): if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const(): ret = self.e(LoadOps.CONTIGUOUS) diff --git a/tinygrad/realize.py b/tinygrad/realize.py index 303acb3f81..0e25206af0 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -107,10 +107,14 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Var return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st)) # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it - if buf.op in {LoadOps.CONTIGUOUS, LoadOps.ASSIGN}: + if buf.op is LoadOps.CONTIGUOUS: assert first - return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, - assign_to=buf.srcs[1].base if buf.op is LoadOps.ASSIGN else None) + return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False) + if buf.op is LoadOps.ASSIGN: + assert first + assert buf.srcs[1].base is buf.srcs[1], "assign must be to base" + assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}" + return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1]) # if it's a reduce, we have to change the shapetracker if buf.op in ReduceOps: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 4290efd19a..70ae2538aa 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -160,8 +160,7 @@ class Tensor: assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}" assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer" assert not x.requires_grad # self requires_grad is okay? - if (self.lazydata.lbs[0].base.realized is None if isinstance(self.lazydata, MultiLazyBuffer) else self.lazydata.base.realized is None): - return self.replace(x) + if not self.lazydata.is_realized(): return self.replace(x) self.lazydata = self.lazydata.assign(x.lazydata) return self def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)