fix up contiguous (#1978)

This commit is contained in:
George Hotz
2023-10-05 07:22:05 -07:00
committed by GitHub
parent c99fa58dd2
commit 21a2c5df73
3 changed files with 17 additions and 18 deletions

View File

@@ -312,5 +312,15 @@ class TestSchedule(unittest.TestCase):
out = bb(x)
check_schedule(out, 4)
def test_contiguous_while_contiguous(self):
x = Tensor.empty(1, 64, 32, 32)
out = x.contiguous()
check_schedule(out, 1, filter_loadops=False)
def test_contiguous_while_not_contiguous(self):
x = Tensor.empty(1, 64, 32, 32)
out = x.permute(0,2,3,1).contiguous()
check_schedule(out, 2, filter_loadops=False)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -161,8 +161,9 @@ class LazyBuffer:
if seen is None: seen = set()
if self in seen or self.realized or self.is_unrealized_const(): return []
seen.add(self)
if self.optype is MovementOps: return self.base.schedule(seen)
if self.base != self: return self.base.schedule(seen)
# rewrite unbased CONTIGUOUS into UnaryOps.NOOP
op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
@@ -175,12 +176,6 @@ class LazyBuffer:
# TODO: this belongs in the schedule in some way
self.var_vals = dict(sorted(merge_dicts([self.var_vals] + [buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
# contiguous can be a copy. must do this after the image hack
if self.op.op == LoadOps.CONTIGUOUS:
src = cast(LazyBuffer, self.op.src[0])
if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
op = self.op
# run the ast and log the op
op, base_bufs = _replace_bufferops(op)
return ret + [(op, self, tuple(base_bufs))]
@@ -198,6 +193,11 @@ class LazyBuffer:
def contiguous(self:LazyBuffer) -> LazyBuffer:
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const():
# this will turn into nothing, it's based and a copy
# TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals, base=self.base)
# real contiguous, this will turn into a UnaryOps.NOOP
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self, val_vals=self.var_vals)
@staticmethod

View File

@@ -20,11 +20,6 @@ def fix_schedule_for_images(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBu
if isinstance(buffers[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
buffers[b.arg.idx-1].dtype = dtypes.float32
# fix the contiguous dtype, no cast required
for op,out,buffers in schedule:
if op.op == LoadOps.CONTIGUOUS and out.dtype != buffers[0].dtype:
out.dtype = buffers[0].dtype = dtypes.float32
# now fix up the schedule to reflect the new dtypes
fixed_schedule = []
for op,out,buffers in schedule:
@@ -82,11 +77,6 @@ def _realize_rand(buffer: LazyBuffer) -> None:
# *** one op LoadOps ***
def _realize_contiguous(buffer: LazyBuffer, src: LazyBuffer) -> None:
# this is just a copy now, if it's not a copy schedule will handle it
buffer.realized = src.realized
assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}"
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}"
assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
@@ -110,7 +100,6 @@ def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
LoadOps.EMPTY: _realize_empty,
LoadOps.RAND: _realize_rand,
LoadOps.CONTIGUOUS: _realize_contiguous,
LoadOps.FROM: _realize_from,
LoadOps.CUSTOM: _realize_custom,
}