Files
tinygrad/test/null/test_tensor_metadata.py
George Hotz 55d3a5def9 preallocate all realized buffers (#14823)
* preallocate all realized buffers

* contiguous

* work

* comment that out

* move to schedule

* better

* correct fix

* just buffer

* disk bufs

* fixes disk tensor stuff

* fix symbolic stuff

* fix multi

* 162 failures

* bugfixes

* don't check that anymore

* fix schedule tests

* mnist should be contiguous

* type and buffer

* fix tests

* shrink axis correction

* mypy fixes

* tests skips

* same 37 failures

* dedup

* no shrink in the graph

* 29 failures

* skips

* fix custom kernel

* fix training

* those optimizations aren't supported currently

* simpler

* more correct

* tests

* 14 failures

* works

* fix that test

* broken

* 11 failures

* only kernel counts left

* fixes

* all tests pass

* remove tensor_map

* op test

* 200 -> 230

* test fixes

* fixes

* revert test_tiny thing

* guard

* revert that

* test tiny passes

* no contigs there

* base realize back

* Revert "no contigs there"

This reverts commit c45bb9fcfd.

* revert that

* chop many assigns

* 12 failures

* fix tests

* tests

* apply after

* pre-commit

* remove old code

* delete that

* fix types

* remove extra contig

* fix dataloader

* torch fix

* disk fix

* update kernel fusion numbers

* runs on amd

* restore kernel count

* add that rule back

* that

* disable that

* wrong

* add the correct rule for that folding

* more tests

* guard c1.arg

* no newlines

* realize those

* split into a different file

* remove detach/contig back

* skip 2

* update that
2026-02-20 20:05:54 +08:00

118 lines
4.4 KiB
Python

import unittest
from tinygrad import Tensor, dtypes
from tinygrad.tensor import _METADATA
from tinygrad.engine.realize import capturing
from tinygrad.helpers import Context
@unittest.skip("tensor metadata is no longer supported")
class TestTensorMetadata(unittest.TestCase):
  """Checks that Tensor ops tag their UOps with Metadata and that this metadata reaches scheduled/executed kernels.

  The whole class is currently skipped by the decorator above; several tests also carry their own skip reasons.
  """
  def setUp(self) -> None:
    # Clear any metadata left in the context var by a previous test.
    _METADATA.set(None)
    # NOTE(review): SCACHE=0 presumably disables the schedule cache so each test schedules from scratch — confirm.
    self._ctx = Context(SCACHE=0)
    self._ctx.__enter__()
  def tearDown(self) -> None:
    # Leave the Context entered in setUp.
    self._ctx.__exit__(None, None, None)

  @unittest.skip("why would this be true?")
  def test_exclude_noop_metadata(self):
    # The `*1` is tagged on the UOp, but the kernel is expected to carry only "rand".
    a = Tensor.rand(4, 4)*1
    self.assertEqual(a.uop.metadata[0].name, "__mul__")
    k = a.schedule()[-1]
    self.assertEqual([m.name for m in k.metadata], ["rand"])

  @unittest.skip("metadata not reaching kernel schedule")
  def test_exclude_const_metadata(self):
    # The kernel for `b` is expected to list only "contiguous", not the constant-producing `full`.
    a = Tensor.arange(4)
    b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
    sched = Tensor.schedule(a, b)
    self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
    self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])

  def test_matmul(self):
    x = Tensor.rand(3, requires_grad=True)
    W = Tensor.rand(3, 3, requires_grad=True)
    out = x.matmul(W)
    self.assertEqual(out.uop.metadata[0].name, "matmul")
    si = out.schedule()[-1]
    self.assertEqual(len(si.metadata), 1)
    self.assertEqual(si.metadata[0].name, "matmul")

  def test_relu(self):
    x = Tensor.rand(3, requires_grad=True)
    out = x.relu()
    self.assertEqual(out.uop.metadata[0].name, "relu")
    si = out.schedule()[-1]
    self.assertEqual(len(si.metadata), 1)
    self.assertEqual(si.metadata[0].name, "relu")

  @unittest.skip("assign metadata no longer captured")
  def test_assign(self):
    x = Tensor.empty(10, 10).realize()
    x.assign(Tensor.ones(10, 10).contiguous())
    si = x.schedule()[-1]
    self.assertEqual(len(si.metadata), 1)
    self.assertEqual(si.metadata[0].name, "assign")

  def test_complex(self):
    # A fused kernel should carry the metadata of every op folded into it.
    x = Tensor.rand(3, requires_grad=True)
    y = Tensor.rand(3, requires_grad=True)
    out = x.relu() * y.sigmoid()
    self.assertEqual(out.uop.metadata[0].name, "__mul__")
    self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
    self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
    si = out.schedule()[-1]
    self.assertEqual(len(si.metadata), 3)
    self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})

  @unittest.skip("flaky")
  def test_complex_backward(self):
    x = Tensor.rand(3, requires_grad=True).realize()
    y = Tensor.rand(3, requires_grad=True).realize()
    out = (x.relu() * y.sigmoid()).sum()
    self.assertEqual(out.uop.metadata[0].name, "sum")
    out.backward()
    self.assertEqual(x.grad.uop.metadata[0].name, "relu")
    #self.assertTrue(x.grad.uop.metadata[0].backward) # TODO: backward flag is False
    self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
    #self.assertTrue(y.grad.uop.metadata[0].backward) # TODO: backward flag is False
    si = Tensor.schedule(out, x.grad, y.grad)[-1]
    #self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
    # skip numpy, this is schedule cache
    self.assertSetEqual(set(m.name for m in si.metadata if m.name != "numpy"), {"sigmoid", "relu"})
    #bw = [m for m in si.metadata if m.backward]
    #self.assertEqual(len(bw), 1)
    #self.assertEqual(bw[0].name, "sigmoid")

  def test_tracemeta_0(self):
    # With TRACEMETA=0 no metadata should be attached to UOps or kernels.
    with Context(TRACEMETA=0):
      x = Tensor.rand(3, requires_grad=True)
      y = Tensor.rand(3, requires_grad=True)
      out = (x.relu() * y.sigmoid()).sum()
      self.assertIsNone(out.uop.metadata)
      self.assertIsNone(out.uop.src[0].metadata)
      si = out.schedule()[-1]
      self.assertEqual(si.metadata, ())

  def _has_metadata(self, h, name):
    """Realize `h` with a temporary capturer installed; return True if any captured item has metadata named `name`.

    Only the capturer added here is removed afterwards, so capturers already present in `capturing` are left intact
    (rather than clearing the whole list).
    """
    items = []
    class _Capturer:
      # Collects every exec item handed to it while `h` is realized.
      def add(_, ei): items.append(ei)
    cap = _Capturer()
    capturing.append(cap)
    try: h.realize()
    finally:
      # Remove only our own capturer; the guard avoids a ValueError here masking an exception from realize().
      if cap in capturing: capturing.remove(cap)
    return any(m.name == name for ei in items for m in ei.metadata)

  def test_metadata_survives_realize_pending_assign(self):
    # A pending assign into a slice of `c` must not drop the metadata of a later op that reads that slice.
    shared = Tensor.rand(4)
    c = Tensor.zeros(8).contiguous().realize()
    c[:4].assign(shared)
    self.assertTrue(self._has_metadata(c[:4].relu(), "relu"))

  @unittest.expectedFailure
  def test_metadata_lost_realize_pending_assign(self):
    # Same as above, but the expression also reads the assign's source `shared`; metadata is currently lost here.
    shared = Tensor.rand(4)
    c = Tensor.zeros(8).contiguous().realize()
    c[:4].assign(shared)
    self.assertTrue(self._has_metadata((c[:4] + shared).relu(), "relu"))
# Run the test suite when this file is executed directly.
if __name__ == '__main__': unittest.main()