restrict assignment to base (#3809)

* restrict assignment to base

* add some restrictions there

* more restrictions
George Hotz, 2024-03-18 15:33:06 -07:00 (committed by GitHub)
parent 20681d5c4a
commit 4c4d3cb3e3
6 changed files with 17 additions and 11 deletions
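
In user-facing terms, the restriction means an assign target must be a realized base buffer (or a contiguous view covering the whole base); assigning through a movement op such as a permute is now rejected. A minimal sketch of the new behavior, assuming the tinygrad Python API at this commit (the tensor names are illustrative):

from tinygrad import Tensor

a = Tensor.ones(4, 4).contiguous().realize()
b = Tensor.ones(4, 4).contiguous().realize()

# in-place assign into a realized base buffer still works
a.assign(a + b).realize()

# assigning through a permuted view of the same buffer is now rejected,
# either by the new assert in LazyBuffer.assign or at schedule time
try:
  a.permute(1, 0).assign(b).realize()
except (RuntimeError, AssertionError) as e:
  print("rejected:", e)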


@@ -69,8 +69,8 @@ class Attention:
       self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
       if isinstance(x.device, tuple):
         # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_k.shard_((xk.device), axis=None)
-        self.cache_v.shard_((xv.device), axis=None)
+        self.cache_k.shard_((xk.device), axis=None).realize()
+        self.cache_v.shard_((xv.device), axis=None).realize()
     # HACK: without contiguous, the conversation mode is broken and the cache is not updated
     keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
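
Presumably (this is an inference, not stated in the diff) the caches are realized up front because they are later the target of an in-place assign, and the scheduler change below asserts that an assign target is already realized. A single-device sketch of that pattern; the names and shapes here are illustrative, not taken from the llama example:

from tinygrad import Tensor

# pre-realize the cache buffer, as the sharded caches above now are
cache_k = Tensor.zeros(1, 16, 4, 8).contiguous().realize()

# later updates write into that same buffer via assign; because the target is
# already realized, the scheduler's "assign must be already realized" check holds
xk = Tensor.ones(1, 16, 4, 8).contiguous()
cache_k.assign(xk).realize()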


@@ -178,7 +178,7 @@ class TestAssign(unittest.TestCase):
     b.realize()
     ba1 = a.lazydata.base.realized
     bb1 = b.lazydata.base.realized
-    with self.assertRaises(RuntimeError):
+    with self.assertRaises((RuntimeError, AssertionError)):
       a = a.permute(1,0)
       a += b
       a.realize()


@@ -57,6 +57,7 @@ class MultiLazyBuffer:
   def is_unrealized_contiguous_const(self): return False
   # passthroughs
+  def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])
   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
   def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
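
The new is_realized passthrough mirrors LazyBuffer.is_realized below, so Tensor.assign can ask either kind of lazydata the same question. A small single-device sketch (illustrative):

from tinygrad import Tensor

t = Tensor.ones(4).contiguous()
print(t.lazydata.is_realized())  # False: no buffer has been allocated yet
t.realize()
print(t.lazydata.is_realized())  # True: the base buffer now exists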


@@ -59,10 +59,12 @@ class LazyBuffer:
     shape = self.shape if shape is None else shape
     return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=cast_scalar(val, self.dtype)).reshape((1,)*len(shape)).expand(shape)
+  def is_realized(self) -> bool: return self.base.realized is not None
   def assign(self, x:LazyBuffer) -> LazyBuffer:
-    if self.base.realized is not None or self is not self.base: new_self = self
-    else: new_self = create_lazybuffer(self.device, self.st, self.dtype, self.op, self.arg, self.srcs, enable_cache=False)
-    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, new_self))
+    assert (self.base is self) or (self.st.contiguous and self.size == self.base.size), f"assign target must be contiguous {self.st}"
+    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self.base))
   def contiguous(self):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
       ret = self.e(LoadOps.CONTIGUOUS)
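
The assert spells out which targets remain legal: the base itself, or a contiguous view covering the whole base. A partial view such as a shrink (like the permuted view in the test above) trips it. A sketch under the same assumptions as earlier:

from tinygrad import Tensor

t = Tensor.ones(2, 3).contiguous().realize()

# a shrunk view covers only part of the base buffer, so it is no longer a valid assign target
try:
  t.shrink(((0, 1), (0, 3))).assign(Tensor.ones(1, 3))
except (RuntimeError, AssertionError) as e:
  print("rejected:", e)  # "assign target must be contiguous ..."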


@@ -107,10 +107,14 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Var
     return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
-  if buf.op in {LoadOps.CONTIGUOUS, LoadOps.ASSIGN}:
+  if buf.op is LoadOps.CONTIGUOUS:
     assert first
-    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False,
-                             assign_to=buf.srcs[1].base if buf.op is LoadOps.ASSIGN else None)
+    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
+  if buf.op is LoadOps.ASSIGN:
+    assert first
+    assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
+    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
+    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1])
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:


@@ -160,8 +160,7 @@ class Tensor:
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
     assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
     assert not x.requires_grad  # self requires_grad is okay?
-    if (self.lazydata.lbs[0].base.realized is None if isinstance(self.lazydata, MultiLazyBuffer) else self.lazydata.base.realized is None):
-      return self.replace(x)
+    if not self.lazydata.is_realized(): return self.replace(x)
     self.lazydata = self.lazydata.assign(x.lazydata)
     return self
   def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
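
Taken together, Tensor.assign now has two clearly separated paths: an unrealized target is simply replaced, while a realized target is updated in place through its existing buffer, which is what the buffer-identity checks in test_assign verify. A short sketch, assuming the API at this commit:

from tinygrad import Tensor

# unrealized target: assign just replaces the lazydata, nothing is written in place
a = Tensor.ones(4)                       # still an unrealized const
a.assign(Tensor.full((4,), 2.0))
print(a.numpy())                         # [2. 2. 2. 2.]

# realized target: assign schedules an in-place write into the existing buffer
b = Tensor.ones(4).contiguous().realize()
buf_before = b.lazydata.base.realized    # the allocated Buffer
b.assign(Tensor.full((4,), 3.0)).realize()
assert b.lazydata.base.realized is buf_before  # same buffer, updated in place
print(b.numpy())                         # [3. 3. 3. 3.]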