mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
* move files that pass with NULL=1 to test/null * fix windows * cpu 0 * bugfix + durations
95 lines
3.6 KiB
Python
95 lines
3.6 KiB
Python
import unittest
|
|
from tinygrad import Tensor, dtypes
|
|
from tinygrad.tensor import _METADATA
|
|
from tinygrad.helpers import Context
|
|
|
|
class TestTensorMetadata(unittest.TestCase):
  """Verify that Tensor-level op names (relu, matmul, ...) are captured as metadata
  on the resulting UOps and propagated onto the scheduled kernels."""

  def setUp(self) -> None:
    # reset the thread-local metadata slot so no state leaks between tests
    _METADATA.set(None)
    # disable the schedule cache for the duration of each test
    self._scache_off = Context(SCACHE=0)
    self._scache_off.__enter__()

  def tearDown(self) -> None:
    self._scache_off.__exit__(None, None, None)

  @unittest.skip("why would this be true?")
  def test_exclude_noop_metadata(self):
    t = Tensor.rand(4, 4)*1
    self.assertEqual(t.uop.metadata[0].name, "__mul__")
    kernel = t.schedule()[-1]
    # the no-op *1 should not show up on the kernel, only the rand
    self.assertEqual([m.name for m in kernel.metadata], ["rand"])

  @unittest.skip("metadata not reaching kernel schedule")
  def test_exclude_const_metadata(self):
    t0 = Tensor.arange(4)
    t1 = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
    sched = Tensor.schedule(t0, t1)
    self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
    self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])

  def test_matmul(self):
    x = Tensor.rand(3, requires_grad=True)
    W = Tensor.rand(3, 3, requires_grad=True)
    out = x.matmul(W)
    self.assertEqual(out.uop.metadata[0].name, "matmul")
    item = out.schedule()[-1]
    # a single matmul op produces a single metadata entry on the kernel
    self.assertEqual(len(item.metadata), 1)
    self.assertEqual(item.metadata[0].name, "matmul")

  def test_relu(self):
    x = Tensor.rand(3, requires_grad=True)
    out = x.relu()
    self.assertEqual(out.uop.metadata[0].name, "relu")
    item = out.schedule()[-1]
    self.assertEqual(len(item.metadata), 1)
    self.assertEqual(item.metadata[0].name, "relu")

  @unittest.skip("assign metadata no longer captured")
  def test_assign(self):
    buf = Tensor.empty(10, 10).realize()
    buf.assign(Tensor.ones(10, 10).contiguous())
    item = buf.schedule()[-1]
    self.assertEqual(len(item.metadata), 1)
    self.assertEqual(item.metadata[0].name, "assign")

  def test_complex(self):
    x = Tensor.rand(3, requires_grad=True)
    y = Tensor.rand(3, requires_grad=True)
    out = x.relu() * y.sigmoid()
    # each source UOp carries the op that produced it
    self.assertEqual(out.uop.metadata[0].name, "__mul__")
    self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
    self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
    item = out.schedule()[-1]
    # all three ops fuse into one kernel, so all three names are present
    self.assertEqual(len(item.metadata), 3)
    self.assertEqual({m.name for m in item.metadata}, {"relu", "sigmoid", "__mul__"})

  def test_complex_backward(self):
    x = Tensor.rand(3, requires_grad=True).realize()
    y = Tensor.rand(3, requires_grad=True).realize()
    out = (x.relu() * y.sigmoid()).sum()
    self.assertEqual(out.uop.metadata[0].name, "sum")
    out.backward()
    # gradients are tagged with the forward op they differentiate
    self.assertEqual(x.grad.uop.metadata[0].name, "relu")
    #self.assertTrue(x.grad.uop.metadata[0].backward) # TODO: backward flag is False
    self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
    #self.assertTrue(y.grad.uop.metadata[0].backward) # TODO: backward flag is False
    item = Tensor.schedule(out, x.grad, y.grad)[-1]
    #self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
    # skip numpy, this is schedule cache
    self.assertSetEqual({m.name for m in item.metadata if m.name != "numpy"}, {"sigmoid", "relu"})
    #bw = [m for m in si.metadata if m.backward]
    #self.assertEqual(len(bw), 1)
    #self.assertEqual(bw[0].name, "sigmoid")

  def test_tracemeta_0(self):
    # with tracing disabled, no metadata should be recorded anywhere
    with Context(TRACEMETA=0):
      x = Tensor.rand(3, requires_grad=True)
      y = Tensor.rand(3, requires_grad=True)
      out = (x.relu() * y.sigmoid()).sum()
      self.assertIsNone(out.uop.metadata)
      self.assertIsNone(out.uop.src[0].metadata)
      item = out.schedule()[-1]
      self.assertEqual(item.metadata, ())
|
|
|
|
# allow running this test file directly
if __name__ == "__main__":
  unittest.main()
|