From 21a2c5df7359b4e626666d59434bd0813088648e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 5 Oct 2023 07:22:05 -0700 Subject: [PATCH] fix up contiguous (#1978) --- test/test_schedule.py | 10 ++++++++++ tinygrad/lazy.py | 14 +++++++------- tinygrad/realize.py | 11 ----------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 7ad0f817de..2a90f23421 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -312,5 +312,15 @@ class TestSchedule(unittest.TestCase): out = bb(x) check_schedule(out, 4) + def test_contiguous_while_contiguous(self): + x = Tensor.empty(1, 64, 32, 32) + out = x.contiguous() + check_schedule(out, 1, filter_loadops=False) + + def test_contiguous_while_not_contiguous(self): + x = Tensor.empty(1, 64, 32, 32) + out = x.permute(0,2,3,1).contiguous() + check_schedule(out, 2, filter_loadops=False) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index ea51adc225..b35184a192 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -161,8 +161,9 @@ class LazyBuffer: if seen is None: seen = set() if self in seen or self.realized or self.is_unrealized_const(): return [] seen.add(self) - if self.optype is MovementOps: return self.base.schedule(seen) + if self.base != self: return self.base.schedule(seen) + # rewrite unbased CONTIGUOUS into UnaryOps.NOOP op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src) if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape) @@ -175,12 +176,6 @@ class LazyBuffer: # TODO: this belongs in the schedule in some way self.var_vals = dict(sorted(merge_dicts([self.var_vals] + [buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key)) - # contiguous can be a copy. must do this after the image hack - if self.op.op == LoadOps.CONTIGUOUS: - src = cast(LazyBuffer, self.op.src[0]) - if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const(): - op = self.op - # run the ast and log the op op, base_bufs = _replace_bufferops(op) return ret + [(op, self, tuple(base_bufs))] @@ -198,6 +193,11 @@ class LazyBuffer: def contiguous(self:LazyBuffer) -> LazyBuffer: if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one + if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const(): + # this will turn into nothing, it's based and a copy + # TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops + return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals, base=self.base) + # real contiguous, this will turn into a UnaryOps.NOOP return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self, val_vals=self.var_vals) @staticmethod diff --git a/tinygrad/realize.py b/tinygrad/realize.py index 020c1a7bb6..9edba50161 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -20,11 +20,6 @@ def fix_schedule_for_images(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBu if isinstance(buffers[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())): buffers[b.arg.idx-1].dtype = dtypes.float32 - # fix the contiguous dtype, no cast required - for op,out,buffers in schedule: - if op.op == LoadOps.CONTIGUOUS and out.dtype != buffers[0].dtype: - out.dtype = buffers[0].dtype = dtypes.float32 - # now fix up the schedule to reflect the new dtypes fixed_schedule = [] for op,out,buffers in schedule: @@ -82,11 +77,6 @@ def _realize_rand(buffer: LazyBuffer) -> None: # *** one op LoadOps *** -def _realize_contiguous(buffer: LazyBuffer, src: LazyBuffer) -> None: - # this is just a copy now, if it's not a copy schedule will handle it - buffer.realized = src.realized - assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}" - def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None: assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}" assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from" @@ -110,7 +100,6 @@ def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None: LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = { LoadOps.EMPTY: _realize_empty, LoadOps.RAND: _realize_rand, - LoadOps.CONTIGUOUS: _realize_contiguous, LoadOps.FROM: _realize_from, LoadOps.CUSTOM: _realize_custom, }