mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
assert set equality in TestTensorMetadata [pr] (#7364)
This commit is contained in:
@@ -709,41 +709,36 @@ class TestInferenceMode(unittest.TestCase):
|
||||
f(x, m, W)
|
||||
|
||||
class TestTensorMetadata(unittest.TestCase):
|
||||
def setUp(self) -> None: _METADATA.set(None)
|
||||
def test_matmul(self):
|
||||
_METADATA.set(None)
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
out = x.matmul(W)
|
||||
self.assertEqual(out.lazydata.metadata.name, "matmul")
|
||||
s = create_schedule([out.lazydata])
|
||||
self.assertEqual(len(s[-1].metadata), 1)
|
||||
self.assertEqual(s[-1].metadata[0].name, "matmul")
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "matmul")
|
||||
|
||||
def test_relu(self):
|
||||
_METADATA.set(None)
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu()
|
||||
self.assertEqual(out.lazydata.metadata.name, "relu")
|
||||
s = create_schedule([out.lazydata])
|
||||
self.assertEqual(len(s[-1].metadata), 1)
|
||||
self.assertEqual(s[-1].metadata[0].name, "relu")
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "relu")
|
||||
|
||||
def test_complex(self):
|
||||
_METADATA.set(None)
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
y = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu() * y.sigmoid()
|
||||
self.assertEqual(out.lazydata.metadata.name, "__mul__")
|
||||
self.assertEqual(out.lazydata.srcs[0].metadata.name, "relu")
|
||||
self.assertEqual(out.lazydata.srcs[1].metadata.name, "sigmoid")
|
||||
s = create_schedule([out.lazydata])
|
||||
self.assertEqual(len(s[-1].metadata), 3)
|
||||
self.assertEqual(s[-1].metadata[0].name, "relu")
|
||||
self.assertEqual(s[-1].metadata[1].name, "sigmoid")
|
||||
self.assertEqual(s[-1].metadata[2].name, "__mul__")
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
|
||||
def test_complex_backward(self):
|
||||
_METADATA.set(None)
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
y = Tensor.rand(3, requires_grad=True)
|
||||
out = (x.relu() * y.sigmoid()).sum()
|
||||
@@ -753,12 +748,12 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertTrue(x.grad.lazydata.metadata.backward)
|
||||
self.assertEqual(y.grad.lazydata.metadata.name, "sigmoid")
|
||||
self.assertTrue(y.grad.lazydata.metadata.backward)
|
||||
s = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])
|
||||
self.assertEqual(len(s[-1].metadata), 3)
|
||||
self.assertEqual(s[-1].metadata[0].name, "sigmoid")
|
||||
self.assertEqual(s[-1].metadata[1].name, "sigmoid")
|
||||
self.assertTrue(s[-1].metadata[1].backward)
|
||||
self.assertEqual(s[-1].metadata[2].name, "relu")
|
||||
si = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])[-1]
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"sigmoid", "sigmoid", "relu"})
|
||||
bw = [m for m in si.metadata if m.backward]
|
||||
self.assertEqual(len(bw), 1)
|
||||
self.assertEqual(bw[0].name, "sigmoid")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user