mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
rename lazydata to uop (#10698)
This commit is contained in:
@@ -565,17 +565,17 @@ class TestZeroShapeTensor(unittest.TestCase):
|
||||
t = Tensor.empty(3, 2, 0)
|
||||
assert t.shape == (3, 2, 0)
|
||||
# numpy has stride 0, 0, 0; torch has stride 2, 1, 1
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 0)
|
||||
assert t.uop.st.real_strides() == (0, 0, 0)
|
||||
|
||||
t = Tensor.empty(3, 0, 2)
|
||||
assert t.shape == (3, 0, 2)
|
||||
# numpy has stride 0, 0, 0; torch has stride 2, 2, 1
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 0)
|
||||
assert t.uop.st.real_strides() == (0, 0, 0)
|
||||
|
||||
t = Tensor.empty(0, 0, 0)
|
||||
assert t.shape == (0, 0, 0)
|
||||
# numpy has stride 0, 0, 0; torch has stride 1, 1, 1
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 0)
|
||||
assert t.uop.st.real_strides() == (0, 0, 0)
|
||||
|
||||
def test_rand(self):
|
||||
t = Tensor.rand(3, 2, 0)
|
||||
@@ -690,24 +690,24 @@ class TestZeroShapeTensor(unittest.TestCase):
|
||||
a = Tensor.rand(16, 16).realize()
|
||||
b = a.clone()
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy())
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)
|
||||
|
||||
a = Tensor.rand(16, 16).mul(5.0).add(5.0).realize()
|
||||
b = a.clone()
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy())
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)
|
||||
|
||||
def test_clone_with_shrink(self):
|
||||
a = Tensor.rand(16, 16)
|
||||
b = a.shrink(((2, 10), None)).clone()
|
||||
b.realize()
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)
|
||||
|
||||
def test_clone_with_shrink_realized(self):
|
||||
a = Tensor.rand(16, 16).realize()
|
||||
b = a.shrink(((2, 10), None)).clone()
|
||||
b.realize()
|
||||
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)
|
||||
|
||||
def test_clone_with_grad(self):
|
||||
a = Tensor.rand(16, 16, requires_grad=True)
|
||||
@@ -780,7 +780,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
@unittest.skip("why would this be true?")
|
||||
def test_exclude_noop_metadata(self):
|
||||
a = Tensor.rand(4, 4)*1
|
||||
self.assertEqual(a.lazydata.metadata[0].name, "__mul__")
|
||||
self.assertEqual(a.uop.metadata[0].name, "__mul__")
|
||||
k = a.schedule()[-1]
|
||||
self.assertEqual([m.name for m in k.metadata], ["rand"])
|
||||
|
||||
@@ -797,7 +797,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
out = x.matmul(W)
|
||||
self.assertEqual(out.lazydata.metadata[0].name, "matmul")
|
||||
self.assertEqual(out.uop.metadata[0].name, "matmul")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "matmul")
|
||||
@@ -805,7 +805,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
def test_relu(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu()
|
||||
self.assertEqual(out.lazydata.metadata[0].name, "relu")
|
||||
self.assertEqual(out.uop.metadata[0].name, "relu")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "relu")
|
||||
@@ -814,9 +814,9 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
y = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu() * y.sigmoid()
|
||||
self.assertEqual(out.lazydata.metadata[0].name, "__mul__")
|
||||
self.assertEqual(out.lazydata.src[0].metadata[0].name, "relu")
|
||||
self.assertEqual(out.lazydata.src[1].metadata[0].name, "sigmoid")
|
||||
self.assertEqual(out.uop.metadata[0].name, "__mul__")
|
||||
self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
|
||||
self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
@@ -825,12 +825,12 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
x = Tensor.rand(3, requires_grad=True).realize()
|
||||
y = Tensor.rand(3, requires_grad=True).realize()
|
||||
out = (x.relu() * y.sigmoid()).sum()
|
||||
self.assertEqual(out.lazydata.metadata[0].name, "sum")
|
||||
self.assertEqual(out.uop.metadata[0].name, "sum")
|
||||
out.backward()
|
||||
self.assertEqual(x.grad.lazydata.metadata[0].name, "relu")
|
||||
self.assertTrue(x.grad.lazydata.metadata[0].backward)
|
||||
self.assertEqual(y.grad.lazydata.metadata[0].name, "sigmoid")
|
||||
self.assertTrue(y.grad.lazydata.metadata[0].backward)
|
||||
self.assertEqual(x.grad.uop.metadata[0].name, "relu")
|
||||
self.assertTrue(x.grad.uop.metadata[0].backward)
|
||||
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
|
||||
self.assertTrue(y.grad.uop.metadata[0].backward)
|
||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
|
||||
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
|
||||
|
||||
Reference in New Issue
Block a user