fix using clone with shrink [pr] (#8724)

* fix using clone with shrink [pr]

* remove extra arg, add test_clone_with_shrink_realized
This commit is contained in:
qazal
2025-01-23 01:28:07 -05:00
committed by GitHub
parent af65331b76
commit 6cb74bb630

View File

@@ -652,19 +652,26 @@ class TestZeroShapeTensor(unittest.TestCase):
def test_clone(self):
a = Tensor.rand(16, 16).realize()
self.assertIsNot(a.lazydata, a.clone().lazydata)
np.testing.assert_allclose(a.numpy(), a.clone().numpy())
b = a.clone()
np.testing.assert_allclose(a.numpy(), b.numpy())
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
a = Tensor.rand(16, 16).mul(5.0).add(5.0)
self.assertIsNot(a.lazydata, a.clone().lazydata)
np.testing.assert_allclose(a.numpy(), a.clone().numpy())
b = a.clone()
np.testing.assert_allclose(a.numpy(), b.numpy())
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
def test_clone_with_shrink(self):
a = Tensor.empty(16, 16)
self.assertIsNot(a.lazydata, a.clone().lazydata)
a = Tensor.rand(16, 16)
b = a.shrink(((2, 10), None)).clone()
b.realize()
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
b = a.shrink(((2, 10), None))
self.assertIsNot(b.lazydata, b.clone().lazydata)
def test_clone_with_shrink_realized(self):
a = Tensor.rand(16, 16).realize()
b = a.shrink(((2, 10), None)).clone()
b.realize()
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
def test_clone_with_grad(self):
a = Tensor.rand(16, 16, requires_grad=True)