Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00

faster wino compile by catting consts across data expand dim (#3293)

* PoC faster wino compile by catting consts across data expand dim
* fix fusions
* faster + golf it
* noqa 501
* implicit broadcast
* Revert "implicit broadcast"

  This reverts commit 5915a9083d045ec1e6be84dcb492333325d48666.

* shorter
* shorter
* oops
* 216 upcasts is probably fine
* wino kernel count test
* test winograd number of sts
* specify device for apply_matrix mat elements

This commit is contained in:
@@ -332,28 +332,32 @@ class TestHandCodedOpts(unittest.TestCase):
   def test_masked_upcast_wino_full(self):
     with Context(WINO=1):
-      x,w = Tensor.rand(1,4,9,9, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
+      x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
       out = Tensor.conv2d(x,w, padding=1)
       upcasts = []
+      wino_schedule = out.lazydata.schedule()
       # collect upcasts of tile transform kernels
-      for i, si in enumerate(out.lazydata.schedule()):
+      for i, si in enumerate(wino_schedule):
         k = Linearizer(si.ast)
         k.hand_coded_optimizations()
         if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
-        if len(k.bufs) < 100: continue # not a tile transform kernel (there's a permute kernel at the end)
+        if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
         upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len]))
       assert len(upcasts) == 3 # 3 transformation matrices
-      # TODO: what did this fix?
+      assert len(wino_schedule) <= 4 # 4 kernels
+      # this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess
       assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1

       out.mean().backward()
-      for si in x.grad.lazydata.schedule() + w.grad.lazydata.schedule():
+      backward_schedule = x.grad.lazydata.schedule() + w.grad.lazydata.schedule()
+      for si in backward_schedule:
         k = Linearizer(si.ast)
         k.hand_coded_optimizations()
         k.linearize()
         if len(k.bufs) < 20: continue # not a tile transform kernel
         # heuristic number to make sure that at least some upcasts but not too many upcasts are being done
-        assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 49
+        assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 216
+      assert len(backward_schedule) <= 13 # just the current number, but it could be better

   def test_masked_upcast_many(self):
     layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4))
||||
@@ -28,6 +28,7 @@ class TestWinograd(unittest.TestCase):
     l = Linearizer(s.ast)
     l.hand_coded_optimizations()
     l.linearize()
+    assert len(l.sts) <= 256 # just the current value to prevent regression
     if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
     for st in l.sts:
       assert len(st.views) <= 2, "too many views in winograd"
||||
Reference in New Issue
Block a user