diff --git a/test/test_schedule.py b/test/test_schedule.py index 5878b58eaf..407cf864f6 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -347,14 +347,14 @@ class TestSchedule(unittest.TestCase): # a and b share the same underlying device memory self.assertIs(a.lazydata.realized, b.lazydata.realized) - def test_copy_dedups(self): + def test_clone_doesnt_dedup(self): src = Tensor.ones(4).contiguous().realize() a = src.clone() b = src.clone() - sched = check_schedule([a, b], 1, filter_sink=False) + sched = check_schedule([a, b], 2, filter_sink=False) run_schedule(sched) # a and b are assigned to the same device Buffer - self.assertIs(a.lazydata.realized, b.lazydata.realized) + self.assertIsNot(a.lazydata.realized, b.lazydata.realized) # EMPTY is assigned to a unique device Buffer @@ -2337,7 +2337,7 @@ class TestCopyFolding(unittest.TestCase): self.assertIs(b.base, a.base) def test_clone(self): - a = Tensor.empty(4).lazydata + a = Tensor.empty(4) check_schedule(a.clone(), 1, filter_sink=False) # NOTE: moving copy before view might change this @@ -2346,7 +2346,7 @@ class TestCopyFolding(unittest.TestCase): view = a.shrink(((0, 2),)) b = view.clone() # NOTE: this was sort of a bug making this 2 - run_schedule(check_schedule(b, 3, filter_sink=False)) + run_schedule(check_schedule(b, 2, filter_sink=False)) self.assertEqual(b.lazydata.base.buffer.size, 2) self.assertEqual(b.lazydata.size, 2) self.assertListEqual(b.tolist(), [0, 1]) @@ -2356,7 +2356,7 @@ class TestCopyFolding(unittest.TestCase): view = a.reshape(2, 1).expand(2, 2) b = view.clone() run_schedule(check_schedule(b, 2, filter_sink=False)) - self.assertEqual(b.lazydata.base.buffer.size, 2) + self.assertEqual(b.lazydata.base.buffer.size, 4) self.assertEqual(b.lazydata.size, 4) self.assertListEqual(b.tolist(), [[0, 0], [1, 1]]) diff --git a/test/test_tensor.py b/test/test_tensor.py index d99a0635b2..ea29c4e4a2 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -688,7 +688,7 @@ class TestZeroShapeTensor(unittest.TestCase): np.testing.assert_allclose(a.numpy(), b.numpy()) self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) - a = Tensor.rand(16, 16).mul(5.0).add(5.0) + a = Tensor.rand(16, 16).mul(5.0).add(5.0).realize() b = a.clone() np.testing.assert_allclose(a.numpy(), b.numpy()) self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 6f699fa51e..b0c291380f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -291,7 +291,6 @@ class Tensor(MathTrait): assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}" assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}" assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}" - assert not x.requires_grad # self requires_grad is okay? self.lazydata = self.lazydata.assign(x.lazydata) return self @@ -366,9 +365,9 @@ class Tensor(MathTrait): """ Creates a clone of this tensor allocating a separate buffer for the data. """ - ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad) + ret = Tensor.empty(self.shape, device=self.device, dtype=self.dtype) if self.grad is not None: ret.grad = self.grad.clone() - return ret + return ret.assign(self) def to(self, device:str|tuple[str, ...]|None) -> Tensor: """ diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 7b1c9ecd35..fc01d89853 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -493,8 +493,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # TODO: this contiguous should not be required!!! inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self.contiguous(),), arg=arg) return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device)) - #return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), arg) - def clone(self) -> UOp: return self.copy_to_device(self.device) def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg) @property def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None)