mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
* preallocate all realized buffers
* contiguous
* work
* comment that out
* move to schedule
* better
* correct fix
* just buffer
* disk bufs
* fixes disk tensor stuff
* fix symbolic stuff
* fix multi
* 162 failures
* bugfixes
* don't check that anymore
* fix schedule tests
* mnist should be contiguous
* type and buffer
* fix tests
* shrink axis correction
* mypy fixes
* tests skips
* same 37 failures
* dedup
* no shrink in the graph
* 29 failures
* skips
* fix custom kernel
* fix training
* those optimizations aren't supported currently
* simpler
* more correct
* tests
* 14 failures
* works
* fix that test
* broken
* 11 failures
* only kernel counts left
* fixes
* all tests pass
* remove tensor_map
* op test
* 200 -> 230
* test fixes
* fixes
* revert test_tiny thing
* guard
* revert that
* test tiny passes
* no contigs there
* base realize back
* Revert "no contigs there"
This reverts commit c45bb9fcfd.
* revert that
* chop many assigns
* 12 failures
* fix tests
* tests
* apply after
* pre-commit
* remove old code
* delete that
* fix types
* remove extra contig
* fix dataloader
* torch fix
* disk fix
* update kernel fusion numbers
* runs on amd
* restore kernel count
* add that rule back
* that
* disable that
* wrong
* add the correct rule for that folding
* more tests
* guard c1.arg
* no newlines
* realize those
* split into a different file
* remove detach/contig back
* skip 2
* update that
118 lines
4.4 KiB
Python
118 lines
4.4 KiB
Python
import unittest
|
|
from tinygrad import Tensor, dtypes
|
|
from tinygrad.tensor import _METADATA
|
|
from tinygrad.engine.realize import capturing
|
|
from tinygrad.helpers import Context
|
|
|
|
@unittest.skip("tensor metadata is no longer supported")
class TestTensorMetadata(unittest.TestCase):
  """Tests that op-level metadata (op names such as "relu" or "matmul")
  attached to Tensor UOps survives through to the scheduled kernel items."""

  def setUp(self) -> None:
    # reset any metadata left over from a previous test and disable the
    # schedule cache for the duration of the test
    _METADATA.set(None)
    self._ctx = Context(SCACHE=0)
    self._ctx.__enter__()

  def tearDown(self) -> None:
    self._ctx.__exit__(None, None, None)

  @unittest.skip("why would this be true?")
  def test_exclude_noop_metadata(self):
    # multiplying by 1 is a no-op; its metadata should not reach the kernel
    t = Tensor.rand(4, 4)*1
    self.assertEqual(t.uop.metadata[0].name, "__mul__")
    last = t.schedule()[-1]
    self.assertEqual([m.name for m in last.metadata], ["rand"])

  @unittest.skip("metadata not reaching kernel schedule")
  def test_exclude_const_metadata(self):
    ar = Tensor.arange(4)
    fill = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
    sched = Tensor.schedule(ar, fill)
    self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
    self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])

  def test_matmul(self):
    vec = Tensor.rand(3, requires_grad=True)
    mat = Tensor.rand(3, 3, requires_grad=True)
    out = vec.matmul(mat)
    self.assertEqual(out.uop.metadata[0].name, "matmul")
    # the final schedule item should carry exactly the matmul metadata
    item = out.schedule()[-1]
    self.assertEqual(len(item.metadata), 1)
    self.assertEqual(item.metadata[0].name, "matmul")

  def test_relu(self):
    src = Tensor.rand(3, requires_grad=True)
    out = src.relu()
    self.assertEqual(out.uop.metadata[0].name, "relu")
    item = out.schedule()[-1]
    self.assertEqual(len(item.metadata), 1)
    self.assertEqual(item.metadata[0].name, "relu")

  @unittest.skip("assign metadata no longer captured")
  def test_assign(self):
    buf = Tensor.empty(10, 10).realize()
    buf.assign(Tensor.ones(10, 10).contiguous())
    item = buf.schedule()[-1]
    self.assertEqual(len(item.metadata), 1)
    self.assertEqual(item.metadata[0].name, "assign")

  def test_complex(self):
    lhs = Tensor.rand(3, requires_grad=True)
    rhs = Tensor.rand(3, requires_grad=True)
    out = lhs.relu() * rhs.sigmoid()
    # metadata of the fused expression covers the mul and both sources
    self.assertEqual(out.uop.metadata[0].name, "__mul__")
    self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
    self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
    item = out.schedule()[-1]
    self.assertEqual(len(item.metadata), 3)
    self.assertEqual({m.name for m in item.metadata}, {"relu", "sigmoid", "__mul__"})

  @unittest.skip("flaky")
  def test_complex_backward(self):
    a = Tensor.rand(3, requires_grad=True).realize()
    b = Tensor.rand(3, requires_grad=True).realize()
    out = (a.relu() * b.sigmoid()).sum()
    self.assertEqual(out.uop.metadata[0].name, "sum")
    out.backward()
    # gradients keep the forward op's metadata name
    self.assertEqual(a.grad.uop.metadata[0].name, "relu")
    #self.assertTrue(a.grad.uop.metadata[0].backward) # TODO: backward flag is False
    self.assertEqual(b.grad.uop.metadata[0].name, "sigmoid")
    #self.assertTrue(b.grad.uop.metadata[0].backward) # TODO: backward flag is False
    item = Tensor.schedule(out, a.grad, b.grad)[-1]
    #self.assertEqual(len(item.metadata), 3, f"failed with {item.metadata}")
    # skip numpy, this is schedule cache
    self.assertSetEqual({m.name for m in item.metadata if m.name != "numpy"}, {"sigmoid", "relu"})
    #bw = [m for m in item.metadata if m.backward]
    #self.assertEqual(len(bw), 1)
    #self.assertEqual(bw[0].name, "sigmoid")

  def test_tracemeta_0(self):
    # with TRACEMETA disabled, no metadata should be recorded at all
    with Context(TRACEMETA=0):
      a = Tensor.rand(3, requires_grad=True)
      b = Tensor.rand(3, requires_grad=True)
      out = (a.relu() * b.sigmoid()).sum()
      self.assertIsNone(out.uop.metadata)
      self.assertIsNone(out.uop.src[0].metadata)
      item = out.schedule()[-1]
      self.assertEqual(item.metadata, ())

  def _has_metadata(self, h, name):
    # Realize h while capturing every emitted exec item, then report whether
    # any captured item carries metadata with the given name.
    captured = []
    class _Capture:
      def add(self, ei): captured.append(ei)
    capturing.append(_Capture())
    try:
      h.realize()
    finally:
      capturing.clear()
    return any(m.name == name for ei in captured for m in ei.metadata)

  def test_metadata_survives_realize_pending_assign(self):
    shared = Tensor.rand(4)
    base = Tensor.zeros(8).contiguous().realize()
    base[:4].assign(shared)
    # realizing a view with a pending assign should still surface "relu"
    self.assertTrue(self._has_metadata(base[:4].relu(), "relu"))

  @unittest.expectedFailure
  def test_metadata_lost_realize_pending_assign(self):
    shared = Tensor.rand(4)
    base = Tensor.zeros(8).contiguous().realize()
    base[:4].assign(shared)
    self.assertTrue(self._has_metadata((base[:4] + shared).relu(), "relu"))
|
|
|
|
# Run the test suite when executed directly (e.g. `python test_tensor_metadata.py`).
if __name__ == '__main__':
  unittest.main()
|