diff --git a/test/test_schedule.py b/test/test_schedule.py
index 157a3d7285..877daf994c 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -2272,6 +2272,15 @@ class TestCopyFolding(unittest.TestCase):
     b = a.copy_to_device(a.device)
     check_schedule(b, 0, filter_sink=False)
     b = schedule_graph_rewrite(b)
+    # NOTE: Tensor.empty(4) always creates a VIEW(BUFFER) with ShapeTracker((4,)), we simplify this to just a BUFFER
+    # in the scheduler because the buffer already has shape (4,)
+    self.assertIs(b, a.base)
+
+  def test_copy_to_same_device_alt(self):
+    a = Tensor.empty(4, 4).lazydata
+    b = a.copy_to_device(a.device)
+    check_schedule(b, 0, filter_sink=False)
+    b = schedule_graph_rewrite(b)
     self.assertIs(b, a)
 
   def test_clone(self):
@@ -2455,14 +2464,17 @@ class TestUOpBecome(unittest.TestCase):
     b = Tensor.empty(4, 4)
     add = a+b
     check_schedule(add, 1)
-    assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
+    # NOTE: realized base is always a flat buffer
+    assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
+    # the Tensor UOp can optionally stack a VIEW on top of BUFFER
+    assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(add.lazydata, {})
 
   def test_new_buffer_view(self):
     a = Tensor.empty(4, 4)
     b = Tensor.empty(4, 4)
     add = (a+b).reshape(8, 2)
     check_schedule(add, 1)
-    assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
+    assert UPat(Ops.BUFFER).match(add.lazydata.base, {}) # VIEW is preserved after the becomes rewrite.
     self.assertEqual(add.lazydata.shape, (8, 2))
     assert add.lazydata is not add.lazydata.base
 
@@ -2472,7 +2484,7 @@ class TestUOpBecome(unittest.TestCase):
     b = a*1
     assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
     check_schedule(b, 0)
-    assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
+    assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
     self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
 
   def test_become_const_in_base(self):
diff --git a/test/test_uops.py b/test/test_uops.py
index 3dbb70ace5..b4a1085f14 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -586,7 +586,6 @@ class TestShapeSpec(unittest.TestCase):
     assign.realize()
     self.assertEqual(a.tolist(), [1, 0, 1, 1])
 
-  @unittest.expectedFailure
   def test_buffer_st(self):
     a = UOp.new_buffer(Device.DEFAULT, 10, dtypes.float)
     self.assertEqual(a.st, ShapeTracker.from_shape((10,)))
diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py
index b4d391c2af..b2e41c9133 100644
--- a/test/unit/test_tensor_uop_representation.py
+++ b/test/unit/test_tensor_uop_representation.py
@@ -2,7 +2,10 @@ import unittest
 from tinygrad import Tensor
 from tinygrad.ops import UPat, Ops, UOp
 
-realized_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),))
+# NOTE: unlike before, the base of a realized tensor is always a BUFFER
+realized_pattern = UPat(Ops.BUFFER)
+# after realization, tensor uops become VIEW(BUFFER)
+buffer_view_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),))
 const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),)))
 def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
 def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.lazydata, pat)
@@ -19,9 +22,6 @@ class TestTensorMutates(unittest.TestCase):
     self.assertIsNot(pa, a.lazydata)
     self.assertIsNot(pb, b.lazydata)
     self.assertIsNot(pr, ret.lazydata)
-    # NOTE: this becomes a VIEW(VIEW(BUFFER)) because UOp.view no longer instantly folds contiguous VIEW of the same shape
-    # this is fine because realized exists on the base.
-    # TODO: we can make this always be a VIEW(BUFFER) once BUFFER has a ShapeTracker of shape=(N,)
     for t in [a,b,ret]: is_pattern_uop(t.lazydata.base, realized_pattern)
 
   def test_reshape_is_same_parent(self):
@@ -32,6 +32,9 @@ class TestTensorMutates(unittest.TestCase):
     d.realize()
     is_pattern_uop(d.lazydata.base, realized_pattern)
     is_pattern_uop(c.lazydata.base, realized_pattern)
+    # NOTE: we keep movement ops on top of the buffer view
+    is_pattern_uop(c.lazydata, buffer_view_pattern)
+    is_pattern_uop(d.lazydata, UPat(Ops.RESHAPE, src=(buffer_view_pattern,)))
 
   def test_reshape_is_same_child(self):
     a = Tensor([1,2,3])
@@ -53,8 +56,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
     b = Tensor([4.,5,6]).realize()
     c = a+b
     print(c.lazydata)
-    is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
-    #is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,)))))
+    #is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
+    is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,)))))
 
   def test_const_pattern(self):
     a = Tensor(1)
@@ -112,8 +115,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
     c = a.to("TEST") # NOTE: this isn't checked
     print(c.lazydata)
     # TODO: COPY on a Tensor becomes a VIEW(COPY), this should be done in the scheduler not in ops
-    is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)))
-    #is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))
+    #is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)))
+    is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 9b3dbc5654..79791dd5b9 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -90,7 +90,8 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c
   # SINK is passthrough
   if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
   # skip creating buffers for CONST/BIND/DEVICE/BUFFER
-  if buf.base.is_realized or buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
+  if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
+  if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
   # VIEW is passthrough
   if buf is not buf.base:
     cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
@@ -387,8 +388,8 @@ sym = symbolic_simple+PatternMatcher([
   # remove contiguous if we can just view the buffer
   (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
    lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
-  # double contiguous is one contiguous
-  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.CONTIGUOUS),)), lambda root: root.src[0]),
+  # contiguous/buffer is already contiguous
+  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
   # support for using a contiguous permuted view instead of the parent view if one exists
   (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
   (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index da3a01964b..073e762117 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -286,6 +286,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.MULTI:
       return ShapeTracker.from_shape(
         tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
+    if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
     # these ops define a ShapeTracker from the arg
     if self.op is Ops.VIEW: return self.arg
     if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
@@ -499,7 +500,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @property
   def base(self) -> UOp:
     if self.op in GroupOp.Movement: return self.src[0].base
-    return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
+    return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 else self
   def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
 
   def _mop(self, op:Ops, arg):
@@ -528,7 +529,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
   @property
   def buf_uop(self) -> UOp:
-    if self.op is Ops.BUFFER: return self
+    if self.base.op is Ops.BUFFER: return self.base
     assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW}, f"buf_uop called on {self.op}"
     return self.src[0].buf_uop
   def buf_uop_view(self) -> UOp: return self.buf_uop.view(unwrap(self.st))
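
Illustrative sketch (not part of the patch): a minimal check, assuming the behavior asserted by the tests above, of the invariant this change introduces. After realization the tensor's base is a flat BUFFER whose ShapeTracker is from_shape((size,)), while the tensor UOp itself keeps a VIEW carrying the original shape. The tensors used here are arbitrary placeholders.

from tinygrad import Tensor
from tinygrad.ops import Ops

a, b = Tensor.empty(4, 4), Tensor.empty(4, 4)
t = (a + b).realize()
assert t.lazydata.op is Ops.VIEW          # the tensor UOp stacks a VIEW on top of the buffer
assert t.lazydata.base.op is Ops.BUFFER   # the realized base is always a flat BUFFER
assert t.lazydata.base.shape == (16,)     # BUFFER now defines ShapeTracker.from_shape((self.size,))
assert t.lazydata.shape == (4, 4)         # the VIEW preserves the tensor's original shape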