Files
tinygrad/test/test_winograd.py
George Hotz 275951b730 clean up a few parents -> toposort [pr] (#7984)
* clean up a few parents -> toposort [pr]

* rename to old_parents + sched tests

* a few more

* that one

* second to last

* final
2024-12-02 15:59:31 +08:00

82 lines
2.9 KiB
Python

import unittest
from tinygrad import Tensor, GlobalCounters, dtypes
from tinygrad.ops import Ops
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
class TestWinograd(unittest.TestCase):
  """Tests for tensor convolutions run with the WINO flag enabled.

  setUp forces WINO.value to 1 for every test, and tearDown restores whatever value was set before.
  As a result, a test that flips the flag itself (test_counters) does not leak that change into other tests.
  """
  def setUp(self):
    # remember the previous setting so tearDown can put it back
    self.old = WINO.value
    WINO.value = 1

  def tearDown(self):
    WINO.value = self.old

  def test_speed(self):
    """Time conv2d construction, scheduling and per-kernel linearization, and bound shapetracker complexity."""
    x = Tensor.empty(1,4,9,9)
    w = Tensor.empty(4,4,3,3)
    with Timing("running conv: "):
      out = Tensor.conv2d(x, w)
    with Timing("scheduling: "):
      sched = create_schedule([out.lazydata])
    for i,s in enumerate(sched):
      # only schedule items whose ast is a SINK are handed to Kernel; everything else is skipped
      if s.ast.op is not Ops.SINK: continue
      ops = s.ast.toposort
      with Timing(f"linearize {i} with {len(ops):4d} ops: "):
        k = Kernel(s.ast)
        k.hand_coded_optimizations()
        k.linearize()
      # unittest assertions (unlike bare `assert`) are not stripped under `python -O`
      self.assertLessEqual(len(k.sts), 256)  # just the current value to prevent regression
      if DEBUG >= 2:
        # default=0 keeps the debug print from raising ValueError if a kernel has no shapetrackers
        print(f"{len(k.sts):4d} shapetrackers with max {max((len(st.views) for st in k.sts), default=0)} views")
      for st in k.sts:
        self.assertLessEqual(len(st.views), 2, "too many views in winograd")
        if DEBUG >= 3:
          print(f"{len(st.views):3d} views")
          for v in st.views: print(v)

  def test_profile(self):
    """Realize a small conv2d under the profiler (profiling disabled on CI), then copy the result out with numpy()."""
    x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
    with Profiling(enabled=not CI, sort='time'):
      out = Tensor.conv2d(x,w).realize()
    out.numpy()

  def test_four_kernels(self):
    """With WINO set, realizing this conv2d must launch exactly 4 kernels."""
    x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
    # inputs are realized before the reset, so only the conv's kernels are counted
    GlobalCounters.reset()
    out = Tensor.conv2d(x,w).realize()
    self.assertEqual(GlobalCounters.kernel_count, 4)
    out.numpy()

  @unittest.skipIf(getenv("PTX"), "winograd uses too much in PTX")
  def test_counters(self):
    """Compare global op and memory counters of a conv2d with WINO=1 against the same conv2d with WINO=0."""
    IC, OC, X, Y = 4,4,9,9
    # larger shape for manual benchmarking: OC, IC, X, Y = 512, 256, 8, 8
    x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize()
    GlobalCounters.reset()
    Tensor.conv2d(x,w).realize()
    ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem
    WINO.value = 0  # tearDown restores the original value
    GlobalCounters.reset()
    Tensor.conv2d(x,w).realize()
    ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem
    ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
    print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
    print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
    self.assertLess(ops_ratio, 2.6)  # TODO: there's issues with factorization now
    self.assertLess(mem_ratio, 10)

  def test_dtype(self):
    """conv2d output dtype follows the inputs: default_float for default tensors, half for half tensors."""
    IC, OC, X, Y = 4,4,9,9
    x,w = Tensor.empty(1,IC,Y,X), Tensor.empty(OC,IC,3,3)
    self.assertEqual(Tensor.conv2d(x,w).dtype, dtypes.default_float)
    x,w = Tensor.empty(1,IC,Y,X,dtype=dtypes.half), Tensor.empty(OC,IC,3,3,dtype=dtypes.half)
    self.assertEqual(Tensor.conv2d(x,w).dtype, dtypes.half)
if __name__ == "__main__":
  # run every test case in this module, reporting each test name as it runs
  unittest.main(verbosity=2)