diff --git a/test/test_schedule.py b/test/test_schedule.py
index 6a1191fc51..415c036139 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -2151,9 +2151,9 @@ class TestBigGraph(unittest.TestCase):
     a = Tensor.empty(4, 4, dtype=dtypes.int)
     sink = tensor_rewrite(a*0)
     assert UPat(Ops.CONST, arg=0).match(sink, {})
-    self.assertIs(tensor_rewrite(a*1), a.lazydata)
-    self.assertIs(tensor_rewrite(a+0), a.lazydata)
-    self.assertIs(tensor_rewrite(a//1), a.lazydata)
+    self.assertIs(tensor_rewrite(a*1).base, a.lazydata.base)
+    self.assertIs(tensor_rewrite(a+0).base, a.lazydata.base)
+    self.assertIs(tensor_rewrite(a//1).base, a.lazydata.base)
 
   def test_cast_folding(self):
     a = Tensor(1.0).cast(dtypes.int)
@@ -2310,7 +2310,7 @@ class TestCopyFolding(unittest.TestCase):
     b = a.copy_to_device(a.device)
     check_schedule(b, 0, filter_sink=False)
     b = schedule_graph_rewrite(b)
-    self.assertIs(b, a)
+    self.assertIs(b.base, a.base)
 
   def test_clone(self):
     a = Tensor.empty(4).lazydata
diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py
index b2e41c9133..859bbd9d56 100644
--- a/test/unit/test_tensor_uop_representation.py
+++ b/test/unit/test_tensor_uop_representation.py
@@ -118,5 +118,13 @@ class TestTensorUopRepresentation(unittest.TestCase):
     #is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)))
     is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))
 
+  def test_empty_buf(self):
+    a = Tensor.empty(3, 3)
+    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
+    vi = UOp.variable("i", 1, 3).bind(1)
+    a = Tensor.empty(3, vi)
+    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
+    self.assertEqual(a.lazydata.base.realized.size, 9)
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 87e8918e06..503a707bd7 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -490,8 +490,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if op is Ops.BIND:
       var, val = arg.unbind()
       return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
-    # otherwise it's just a VIEW(BUFFER)
-    return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
+    # otherwise it's just a RESHAPE(BUFFER)
+    if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
+    return UOp.new_buffer(device, size, dtype).reshape(shape)
   def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
     if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
@@ -547,7 +548,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     return self.src[0].buf_uop
   @property
   def buffer(self) -> Buffer:
-    if self.op is Ops.VIEW:
+    if self is not self.base:
       assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
       return self.src[0].buffer
     assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
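
Taken together, these hunks change Tensor.empty to build RESHAPE(BUFFER) instead of VIEW(BUFFER); identity rewrites like a*1 therefore fold to a movement op over the same base buffer rather than to the identical UOp, which is why the tests now compare .base instead of object identity. Below is a minimal sketch (not part of the diff) of the resulting invariants, mirroring test_empty_buf above; it assumes a tinygrad checkout with this patch applied:

from tinygrad import Tensor
from tinygrad.ops import Ops, UOp

# Tensor.empty now yields RESHAPE(BUFFER): the movement op on top, the allocation at .base
a = Tensor.empty(3, 3)
assert a.lazydata.op is Ops.RESHAPE
assert a.lazydata.base.op is Ops.BUFFER

# a symbolic dim is sized by its upper bound (vmax), so 3 * 3 = 9 elements here,
# matching the assertEqual in test_empty_buf
vi = UOp.variable("i", 1, 3).bind(1)
b = Tensor.empty(3, vi)
assert b.lazydata.base.op is Ops.BUFFER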