mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-17 10:02:00 -05:00
new style load/store folder (#5784)
* remove old index reorder * new style folder * works better * dedup * one failure * this is fine now... * expander_rewrite * images broken, but all else should work * cleanups * make tests work with old * fix images * cleanups + bugfix * minor fixes * fix gated store folding * flip gate_creator and expander * fix gated store * remove unneeded rules * lines getting close * line count good
This commit is contained in:
@@ -5,7 +5,7 @@ from tinygrad.dtype import PtrDType
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
|
||||
from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher
|
||||
from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, constant_folder
|
||||
from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding
|
||||
|
||||
simple_pm = PatternMatcher([
|
||||
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
@@ -82,7 +82,7 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('b', 0, 1))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('c', 0, 1))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('d', 0, 1))
|
||||
outs = [2+a, 2+a+d+3+b+c+4] #, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
|
||||
for out in outs:
|
||||
sink = graph_rewrite(out, constant_folder)
|
||||
print(sink)
|
||||
@@ -233,12 +233,12 @@ class TestUOpGraph(TestUOps):
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2)))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3)))
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld0+ld1))])
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[2].src
|
||||
# ld0 becomes the invalid value
|
||||
self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2))
|
||||
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assert_equiv_uops(ld1, UOp.load(glbl2, idx, dtype=dtypes.int))
|
||||
self.assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
|
||||
@@ -248,12 +248,12 @@ class TestUOpGraph(TestUOps):
|
||||
barrier = UOp(UOps.BARRIER, None, (st, ))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2), barrier))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3), barrier))
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld0+ld1))])
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[2].src
|
||||
# ld0 becomes the invalid value
|
||||
self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2))
|
||||
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assert_equiv_uops(ld1, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
|
||||
self.assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
|
||||
@@ -289,12 +289,8 @@ class TestUOpGraph(TestUOps):
|
||||
# ranges are closed in the right order
|
||||
self.assertEqual(endranges[-1].src[0], ranges[0])
|
||||
|
||||
def expander_rewrite(sink):
  """Rewrite `sink` using the expander and constant_folder patterns merged into one PatternMatcher."""
  merged_patterns = expander.patterns + constant_folder.patterns
  return graph_rewrite(sink, PatternMatcher(merged_patterns))
|
||||
def expander_rewrite(sink):
  """Rewrite `sink` with the constant_folder, expander and reducer rules in one graph_rewrite call."""
  rules = constant_folder + expander + reducer
  return graph_rewrite(sink, rules)
|
||||
def float4_rewrite(sink):
  """Rewrite `sink` with the constant_folder, expander and float4_folding rules in one graph_rewrite call."""
  rules = constant_folder + expander + float4_folding
  return graph_rewrite(sink, rules)
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
def test_expand_add_broadcast(self):
|
||||
@@ -424,5 +420,67 @@ class TestExpander(unittest.TestCase):
|
||||
sink = expander_rewrite(sink)
|
||||
print(sink)
|
||||
|
||||
class TestLoadStoreFolder(unittest.TestCase):
  """Tests for load/store folding as performed by `float4_rewrite`.

  Every test builds scalar LOAD or STORE uops on one float buffer at consecutive
  constant indices, runs `float4_rewrite` on the graph, and then counts how many
  LOAD/STORE uops are still reachable through the result's `sparents`.
  """

  @staticmethod
  def _uops_with_op(sink, op):
    """Return the uops found in `sink.sparents` whose `op` is `op`, in iteration order."""
    return [x for x in sink.sparents if x.op is op]

  def test_simple_load_fold(self):
    """Four ungated loads at indices 0..3 leave exactly one LOAD."""
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    loads = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
    sink = UOp(UOps.EXPAND, dtypes.float, tuple(loads), ((0,4),))
    sink = float4_rewrite(sink)
    self.assertEqual(len(self._uops_with_op(sink, UOps.LOAD)), 1)

  def test_two_load_fold(self):
    """Eight ungated loads at indices 0..7 leave exactly two LOADs."""
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    loads = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
    sink = UOp(UOps.EXPAND, dtypes.float, tuple(loads), ((0,8),))
    sink = float4_rewrite(sink)
    self.assertEqual(len(self._uops_with_op(sink, UOps.LOAD)), 2)

  def test_simple_load_fold_gated(self):
    """Four loads sharing one gate leave one LOAD whose src[3] carries the four alt values in order."""
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
    loads = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate, UOp.const(dtypes.float, i))) for i in range(4)]
    sink = UOp(UOps.EXPAND, dtypes.float, tuple(loads), ((0,4),))
    sink = float4_rewrite(sink)
    remaining = self._uops_with_op(sink, UOps.LOAD)
    self.assertEqual(len(remaining), 1)
    single_load = remaining[0]
    self.assertListEqual([src.arg for src in single_load.src[3].src], [0.0, 1.0, 2.0, 3.0])

  def test_simple_load_dont_fold_different_gated(self):
    """When the first load uses a different gate than the other three, three LOADs remain."""
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
    gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
    loads = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate if i == 0 else gate2, UOp.const(dtypes.float, i))) for i in range(4)]
    sink = UOp(UOps.EXPAND, dtypes.float, tuple(loads), ((0,4),))
    sink = float4_rewrite(sink)
    self.assertEqual(len(self._uops_with_op(sink, UOps.LOAD)), 3)

  def test_simple_store_fold(self):
    """Four ungated stores at indices 0..3 leave exactly one STORE."""
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    stores = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i))) for i in range(4)]
    sink = UOp(UOps.SINK, None, tuple(stores))
    sink = float4_rewrite(sink)
    self.assertEqual(len(self._uops_with_op(sink, UOps.STORE)), 1)

  def test_simple_store_fold_gate(self):
    """Four stores sharing one gate leave one four-source STORE that still carries that gate."""
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
    stores = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
    sink = UOp(UOps.SINK, None, tuple(stores))
    sink = float4_rewrite(sink)
    remaining = self._uops_with_op(sink, UOps.STORE)
    self.assertEqual(len(remaining), 1)
    one_store = remaining[0]
    self.assertEqual(len(one_store.src), 4)
    # NOTE(review): the gate is compared through str() because the original author found
    # that was needed here; the reason was left as an open question -- TODO confirm.
    self.assertEqual(str(one_store.src[3]), str(gate))

  def test_simple_store_dont_fold(self):
    """When the first store uses a different gate than the other three, three STOREs remain."""
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
    gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
    stores = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
    sink = UOp(UOps.SINK, None, tuple(stores))
    sink = float4_rewrite(sink)
    self.assertEqual(len(self._uops_with_op(sink, UOps.STORE)), 3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user