faster wino compile by catting consts across data expand dim (#3293)

* PoC faster wino compile by catting consts across data expand dim

* fix fusions

* faster + golf it

* noqa 501

* implicit broadcast

* Revert "implicit broadcast"

This reverts commit 5915a9083d045ec1e6be84dcb492333325d48666.

* shorter

* shorter

* oops

* 216 upcasts is probably fine

* wino kernel count test

* test winograd number of sts

* specify device for apply_matrix mat elements
This commit is contained in:
David Hou
2024-02-02 00:47:45 -08:00
committed by GitHub
parent cf6f478901
commit aebaab011f
3 changed files with 15 additions and 7 deletions

View File

@@ -332,28 +332,32 @@ class TestHandCodedOpts(unittest.TestCase):
def test_masked_upcast_wino_full(self):
with Context(WINO=1):
x,w = Tensor.rand(1,4,9,9, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
out = Tensor.conv2d(x,w, padding=1)
upcasts = []
wino_schedule = out.lazydata.schedule()
# collect upcasts of tile transform kernels
for i, si in enumerate(out.lazydata.schedule()):
for i, si in enumerate(wino_schedule):
k = Linearizer(si.ast)
k.hand_coded_optimizations()
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
if len(k.bufs) < 100: continue # not a tile transform kernel (there's a permute kernel at the end)
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len]))
assert len(upcasts) == 3 # 3 transformation matrices
# TODO: what did this fix?
assert len(wino_schedule) <= 4 # 4 kernels
# this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess
assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1
out.mean().backward()
for si in x.grad.lazydata.schedule() + w.grad.lazydata.schedule():
backward_schedule = x.grad.lazydata.schedule() + w.grad.lazydata.schedule()
for si in backward_schedule:
k = Linearizer(si.ast)
k.hand_coded_optimizations()
k.linearize()
if len(k.bufs) < 20: continue # not a tile transform kernel
# heuristic number to make sure that at least some upcasts but not too many upcasts are being done
assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 49
assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 216
assert len(backward_schedule) <= 13 # just the current number, but it could be better
def test_masked_upcast_many(self):
layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4))

View File

@@ -28,6 +28,7 @@ class TestWinograd(unittest.TestCase):
l = Linearizer(s.ast)
l.hand_coded_optimizations()
l.linearize()
assert len(l.sts) <= 256 # just the current value to prevent regression
if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
for st in l.sts:
assert len(st.views) <= 2, "too many views in winograd"