mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fix tests (#10493)
This commit is contained in:
@@ -347,14 +347,14 @@ class TestSchedule(unittest.TestCase):
|
||||
# a and b share the same underlying device memory
|
||||
self.assertIs(a.lazydata.realized, b.lazydata.realized)
|
||||
|
||||
def test_copy_dedups(self):
|
||||
def test_clone_doesnt_dedup(self):
|
||||
src = Tensor.ones(4).contiguous().realize()
|
||||
a = src.clone()
|
||||
b = src.clone()
|
||||
sched = check_schedule([a, b], 1, filter_sink=False)
|
||||
sched = check_schedule([a, b], 2, filter_sink=False)
|
||||
run_schedule(sched)
|
||||
# a and b are assigned to the same device Buffer
|
||||
self.assertIs(a.lazydata.realized, b.lazydata.realized)
|
||||
self.assertIsNot(a.lazydata.realized, b.lazydata.realized)
|
||||
|
||||
# EMPTY is assigned to a unique device Buffer
|
||||
|
||||
@@ -2337,7 +2337,7 @@ class TestCopyFolding(unittest.TestCase):
|
||||
self.assertIs(b.base, a.base)
|
||||
|
||||
def test_clone(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
a = Tensor.empty(4)
|
||||
check_schedule(a.clone(), 1, filter_sink=False)
|
||||
|
||||
# NOTE: moving copy before view might change this
|
||||
@@ -2346,7 +2346,7 @@ class TestCopyFolding(unittest.TestCase):
|
||||
view = a.shrink(((0, 2),))
|
||||
b = view.clone()
|
||||
# NOTE: this was sort of a bug making this 2
|
||||
run_schedule(check_schedule(b, 3, filter_sink=False))
|
||||
run_schedule(check_schedule(b, 2, filter_sink=False))
|
||||
self.assertEqual(b.lazydata.base.buffer.size, 2)
|
||||
self.assertEqual(b.lazydata.size, 2)
|
||||
self.assertListEqual(b.tolist(), [0, 1])
|
||||
@@ -2356,7 +2356,7 @@ class TestCopyFolding(unittest.TestCase):
|
||||
view = a.reshape(2, 1).expand(2, 2)
|
||||
b = view.clone()
|
||||
run_schedule(check_schedule(b, 2, filter_sink=False))
|
||||
self.assertEqual(b.lazydata.base.buffer.size, 2)
|
||||
self.assertEqual(b.lazydata.base.buffer.size, 4)
|
||||
self.assertEqual(b.lazydata.size, 4)
|
||||
self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])
|
||||
|
||||
|
||||
@@ -688,7 +688,7 @@ class TestZeroShapeTensor(unittest.TestCase):
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy())
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
|
||||
a = Tensor.rand(16, 16).mul(5.0).add(5.0)
|
||||
a = Tensor.rand(16, 16).mul(5.0).add(5.0).realize()
|
||||
b = a.clone()
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy())
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
|
||||
@@ -291,7 +291,6 @@ class Tensor(MathTrait):
|
||||
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
|
||||
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
|
||||
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
||||
assert not x.requires_grad # self requires_grad is okay?
|
||||
self.lazydata = self.lazydata.assign(x.lazydata)
|
||||
return self
|
||||
|
||||
@@ -366,9 +365,9 @@ class Tensor(MathTrait):
|
||||
"""
|
||||
Creates a clone of this tensor allocating a separate buffer for the data.
|
||||
"""
|
||||
ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
|
||||
ret = Tensor.empty(self.shape, device=self.device, dtype=self.dtype)
|
||||
if self.grad is not None: ret.grad = self.grad.clone()
|
||||
return ret
|
||||
return ret.assign(self)
|
||||
|
||||
def to(self, device:str|tuple[str, ...]|None) -> Tensor:
|
||||
"""
|
||||
|
||||
@@ -493,8 +493,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
# TODO: this contiguous should not be required!!!
|
||||
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self.contiguous(),), arg=arg)
|
||||
return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device))
|
||||
#return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), arg)
|
||||
def clone(self) -> UOp: return self.copy_to_device(self.device)
|
||||
def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg)
|
||||
@property
|
||||
def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None)
|
||||
|
||||
Reference in New Issue
Block a user