mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move up migrate + new gated fold (#7403)
* move up migrate + new gated fold [pr] * vcount for const ptr * move those rules there * fix openpilot
This commit is contained in:
@@ -5,7 +5,7 @@ from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
||||
from tinygrad.ops import UPat, PatternMatcher
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding, finalize
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
|
||||
@@ -389,12 +389,10 @@ class TestUOpGraph(unittest.TestCase):
|
||||
glbl1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||
glbl2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False)))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[-1].src
|
||||
# ld0 becomes the invalid value
|
||||
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
|
||||
ld0 = uops[-1].src[-1]
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))
|
||||
|
||||
@@ -404,12 +402,10 @@ class TestUOpGraph(unittest.TestCase):
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
||||
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier))
|
||||
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[-1].src
|
||||
# ld0 becomes the invalid value
|
||||
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
|
||||
ld0 = uops[-1].src[-1]
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0.src[0], smem.index(lidx+2))
|
||||
|
||||
@@ -449,7 +445,8 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def expander_rewrite(sink):
|
||||
sink = graph_rewrite(sink, sym + expander)
|
||||
return graph_rewrite(sink, sym + reducer)
|
||||
sink = graph_rewrite(sink, sym + reducer)
|
||||
return graph_rewrite(sink, sym + finalize)
|
||||
def float4_rewrite(sink): return graph_rewrite(sink, sym + expander + float4_folding)
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
@@ -660,8 +657,6 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
print(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
|
||||
|
||||
def gate_rewrite(sink): return graph_rewrite(sink, sym + expander + reducer)
|
||||
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
def test_create_ifs(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
@@ -675,12 +670,12 @@ class TestIFUOps(unittest.TestCase):
|
||||
lbuf = UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
|
||||
store = UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (store,))
|
||||
sink = gate_rewrite(sink)
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
def test_expand_ifs_one_gate(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
@@ -693,12 +688,12 @@ class TestIFUOps(unittest.TestCase):
|
||||
lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
|
||||
stores = [UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
|
||||
sink = gate_rewrite(sink)
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
# this will be fixed with the merge gated stores bounty
|
||||
@unittest.expectedFailure
|
||||
@@ -709,12 +704,12 @@ class TestIFUOps(unittest.TestCase):
|
||||
gate = valid&(lidx.ne(2))
|
||||
stores = [UOp(UOps.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
|
||||
sink = gate_rewrite(sink)
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user