Files
tinygrad/test/null/test_tensor_metadata.py
George Hotz dc77b3318b move files that pass with NULL=1 to test/null (#14508)
* move files that pass with NULL=1 to test/null

* fix windows

* cpu 0

* bugfix + durations
2026-02-03 13:52:36 +08:00

95 lines
3.6 KiB
Python

import unittest
from tinygrad import Tensor, dtypes
from tinygrad.tensor import _METADATA
from tinygrad.helpers import Context
class TestTensorMetadata(unittest.TestCase):
def setUp(self) -> None:
_METADATA.set(None)
self._ctx = Context(SCACHE=0)
self._ctx.__enter__()
def tearDown(self) -> None:
self._ctx.__exit__(None, None, None)
@unittest.skip("why would this be true?")
def test_exclude_noop_metadata(self):
a = Tensor.rand(4, 4)*1
self.assertEqual(a.uop.metadata[0].name, "__mul__")
k = a.schedule()[-1]
self.assertEqual([m.name for m in k.metadata], ["rand"])
@unittest.skip("metadata not reaching kernel schedule")
def test_exclude_const_metadata(self):
a = Tensor.arange(4)
b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
sched = Tensor.schedule(a, b)
self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])
def test_matmul(self):
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)
out = x.matmul(W)
self.assertEqual(out.uop.metadata[0].name, "matmul")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "matmul")
def test_relu(self):
x = Tensor.rand(3, requires_grad=True)
out = x.relu()
self.assertEqual(out.uop.metadata[0].name, "relu")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "relu")
@unittest.skip("assign metadata no longer captured")
def test_assign(self):
x = Tensor.empty(10, 10).realize()
x.assign(Tensor.ones(10, 10).contiguous())
si = x.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "assign")
def test_complex(self):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = x.relu() * y.sigmoid()
self.assertEqual(out.uop.metadata[0].name, "__mul__")
self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
def test_complex_backward(self):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()
out = (x.relu() * y.sigmoid()).sum()
self.assertEqual(out.uop.metadata[0].name, "sum")
out.backward()
self.assertEqual(x.grad.uop.metadata[0].name, "relu")
#self.assertTrue(x.grad.uop.metadata[0].backward) # TODO: backward flag is False
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
#self.assertTrue(y.grad.uop.metadata[0].backward) # TODO: backward flag is False
si = Tensor.schedule(out, x.grad, y.grad)[-1]
#self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
# skip numpy, this is schedule cache
self.assertSetEqual(set(m.name for m in si.metadata if m.name != "numpy"), {"sigmoid", "relu"})
#bw = [m for m in si.metadata if m.backward]
#self.assertEqual(len(bw), 1)
#self.assertEqual(bw[0].name, "sigmoid")
def test_tracemeta_0(self):
with Context(TRACEMETA=0):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = (x.relu() * y.sigmoid()).sum()
self.assertIsNone(out.uop.metadata)
self.assertIsNone(out.uop.src[0].metadata)
si = out.schedule()[-1]
self.assertEqual(si.metadata, ())
if __name__ == '__main__':
unittest.main()