diff --git a/test/test_search.py b/test/test_search.py index f893230a81..be9395ce83 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -2,7 +2,7 @@ import unittest from tinygrad.codegen.linearizer import Linearizer from tinygrad.realize import create_schedule -from tinygrad.features.search import time_linearizer +from tinygrad.features.search import time_linearizer, bufs_from_lin from tinygrad.device import Compiled, Device, Buffer from tinygrad.ops import LoadOps from tinygrad.tensor import Tensor @@ -12,10 +12,19 @@ class TestTimeLinearizer(unittest.TestCase): if not isinstance(Device[Device.DEFAULT], Compiled): raise unittest.SkipTest("only test for compiled backends") def test_reasonable_time(self): - si = [si for si in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if si.ast.op not in LoadOps][0] + si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op not in LoadOps][0] rawbufs = [Buffer(Device.DEFAULT, si.out.st.real_size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs] tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10) assert tm > 0 and tm != float('inf') + def test_bufs_from_lin(self): + si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op not in LoadOps][0] + rawbufs = bufs_from_lin(lin:=Linearizer(si.ast)) + assert len(rawbufs) == len(lin.membufs) + assert all(r is not None for r in rawbufs) + assert all(isinstance(r, Buffer) for r in rawbufs) + assert all(r.size > 0 for r in rawbufs) + + if __name__ == '__main__': unittest.main()