new style load/store folder (#5784)

* remove old index reorder

* new style folder

* works better

* dedup

* one failure

* this is fine now...

* expander_rewrite

* images broken, but all else should work

* cleanups

* make tests work with old

* fix images

* cleanups + bugfix

* minor fixes

* fix gated store folding

* flip gate_creator and expander

* fix gated store

* remove unneeded rules

* lines getting close

* line count good
This commit is contained in:
George Hotz
2024-07-30 13:17:20 -07:00
committed by GitHub
parent e8a42b945c
commit 17a2f74412
4 changed files with 172 additions and 111 deletions

View File

@@ -5,7 +5,7 @@ from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher
from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, constant_folder
from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding
simple_pm = PatternMatcher([
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
@@ -82,7 +82,7 @@ class TestGraphRewrite(unittest.TestCase):
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('b', 0, 1))
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('c', 0, 1))
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('d', 0, 1))
outs = [2+a, 2+a+d+3+b+c+4] #, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
for out in outs:
sink = graph_rewrite(out, constant_folder)
print(sink)
@@ -233,12 +233,12 @@ class TestUOpGraph(TestUOps):
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3)))
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld0+ld1))])
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2))
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
self.assert_equiv_uops(ld1, UOp.load(glbl2, idx, dtype=dtypes.int))
self.assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
@@ -248,12 +248,12 @@ class TestUOpGraph(TestUOps):
barrier = UOp(UOps.BARRIER, None, (st, ))
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2), barrier))
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3), barrier))
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld0+ld1))])
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2))
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
self.assert_equiv_uops(ld1, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
self.assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
def test_fold_gated_store(self):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
@@ -289,12 +289,8 @@ class TestUOpGraph(TestUOps):
# ranges are closed in the right order
self.assertEqual(endranges[-1].src[0], ranges[0])
def expander_rewrite(sink):
together = PatternMatcher(expander.patterns + constant_folder.patterns)
return graph_rewrite(sink, together)
#out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
#out.linearize()
#return out.uops[-1]
def expander_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer)
def float4_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + float4_folding)
class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
@@ -424,5 +420,67 @@ class TestExpander(unittest.TestCase):
sink = expander_rewrite(sink)
print(sink)
class TestLoadStoreFolder(unittest.TestCase):
def test_simple_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
def test_two_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,8),))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
def test_simple_load_fold_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate, UOp.const(dtypes.float, i))) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
self.assertListEqual([src.arg for src in single_load.src[3].src], [0.0, 1.0, 2.0, 3.0])
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate if i == 0 else gate2, UOp.const(dtypes.float, i))) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
def test_simple_store_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i))) for i in range(4)]
sink = UOp(UOps.SINK, None, tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
def test_simple_store_fold_gate(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.SINK, None, tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
assert len(one_store.src) == 4
assert str(one_store.src[3]) == str(gate) # huh, why do i need str here?
def test_simple_store_dont_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.SINK, None, tuple(load))
sink = float4_rewrite(sink)
print(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
if __name__ == '__main__':
unittest.main(verbosity=2)