mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 15:45:27 -05:00
* PoC faster wino compile by catting consts across data expand dim * fix fusions * faster + golf it * noqa 501 * implicit broadcast * Revert "implicit broadcast" This reverts commit 5915a9083d045ec1e6be84dcb492333325d48666. * shorter * shorter * oops * 216 upcasts is probably fine * wino kernel count test * test winograd number of sts * specify device for apply_matrix mat elements
53 lines
1.7 KiB
Python
53 lines
1.7 KiB
Python
import unittest
|
|
from tinygrad import Tensor, GlobalCounters
|
|
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG
|
|
from tinygrad.ops import LoadOps
|
|
from tinygrad.codegen.linearizer import Linearizer
|
|
|
|
class TestWinograd(unittest.TestCase):
  """Regression tests for conv2d with the WINO flag forced to 1.

  Every test runs a conv2d of a (1,4,9,9) input with a (4,4,3,3) weight.
  Checks use unittest assertion methods instead of bare ``assert`` so they
  still run under ``python -O`` and report the offending values on failure.
  """

  def setUp(self):
    """Save the current WINO value and force it to 1 for the duration of a test."""
    self.old = WINO.value
    WINO.value = 1

  def tearDown(self):
    """Restore the WINO value that setUp saved."""
    WINO.value = self.old

  def test_speed(self):
    """Time conv construction, scheduling and linearization; bound shapetracker and view counts."""
    x = Tensor.empty(1,4,9,9)
    w = Tensor.empty(4,4,3,3)

    with Timing("running conv: "):
      out = Tensor.conv2d(x, w)

    with Timing("scheduling: "):
      sched = out.lazydata.schedule()

    for i,s in enumerate(sched):
      # schedule items whose top-level op is a LoadOps member are not linearized
      if s.ast.op in LoadOps: continue
      ops = s.ast.lazyops
      with Timing(f"linearize {i} with {len(ops):4d} ops: "):
        lin = Linearizer(s.ast)
        lin.hand_coded_optimizations()
        lin.linearize()
      # 256 is just the current value, kept as a ceiling to prevent regression
      self.assertLessEqual(len(lin.sts), 256, f"too many shapetrackers in schedule item {i}")
      if DEBUG >= 2: print(f"{len(lin.sts):4d} shapetrackers with max {max(len(st.views) for st in lin.sts)} views")
      for st in lin.sts:
        self.assertLessEqual(len(st.views), 2, "too many views in winograd")
        if DEBUG >= 3:
          print(f"{len(st.views):3d} views")
          for v in st.views: print(v)

  def test_profile(self):
    """Realize the conv under Profiling (enabled only when CI is falsy), then read the result back."""
    x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
    with Profiling(enabled=not CI, sort='time'):
      out = Tensor.conv2d(x,w).realize()
    out.numpy()

  def test_four_kernels(self):
    """Check that realizing the conv bumps GlobalCounters.kernel_count to exactly 4."""
    x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
    # reset after the inputs are realized so only the conv itself is counted
    GlobalCounters.reset()
    out = Tensor.conv2d(x,w).realize()
    self.assertEqual(GlobalCounters.kernel_count, 4)
    out.numpy()
|
|
|
|
if __name__ == '__main__':
  # Run as a script (not imported): execute every test case with per-test output.
  unittest.main(verbosity=2)