mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fix using clone with shrink [pr] (#8724)
* fix using clone with shrink [pr] * remove extra arg, add test_clone_with_shrink_realized
This commit is contained in:
@@ -652,19 +652,26 @@ class TestZeroShapeTensor(unittest.TestCase):
|
||||
|
||||
def test_clone(self):
|
||||
a = Tensor.rand(16, 16).realize()
|
||||
self.assertIsNot(a.lazydata, a.clone().lazydata)
|
||||
np.testing.assert_allclose(a.numpy(), a.clone().numpy())
|
||||
b = a.clone()
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy())
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
|
||||
a = Tensor.rand(16, 16).mul(5.0).add(5.0)
|
||||
self.assertIsNot(a.lazydata, a.clone().lazydata)
|
||||
np.testing.assert_allclose(a.numpy(), a.clone().numpy())
|
||||
b = a.clone()
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy())
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
|
||||
def test_clone_with_shrink(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
self.assertIsNot(a.lazydata, a.clone().lazydata)
|
||||
a = Tensor.rand(16, 16)
|
||||
b = a.shrink(((2, 10), None)).clone()
|
||||
b.realize()
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
|
||||
b = a.shrink(((2, 10), None))
|
||||
self.assertIsNot(b.lazydata, b.clone().lazydata)
|
||||
def test_clone_with_shrink_realized(self):
|
||||
a = Tensor.rand(16, 16).realize()
|
||||
b = a.shrink(((2, 10), None)).clone()
|
||||
b.realize()
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
|
||||
def test_clone_with_grad(self):
|
||||
a = Tensor.rand(16, 16, requires_grad=True)
|
||||
|
||||
Reference in New Issue
Block a user