Revert "improve full_graph_rewrite matchers for speed (#7431)" (#7434)

This reverts commit 996152d2de.
This commit is contained in:
George Hotz
2024-10-31 15:16:47 +07:00
committed by GitHub
parent 996152d2de
commit 2e3048fc57
3 changed files with 50 additions and 47 deletions

View File

@@ -4,9 +4,8 @@ from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, load_store_indexing, sym, float4_folding, migrate_indexing
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding, finalize, migrate_indexing
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.shape.shapetracker import ShapeTracker, View
@@ -446,10 +445,12 @@ class TestUOpGraph(unittest.TestCase):
def expander_rewrite(sink):
sink = graph_rewrite(sink, sym + expander)
return graph_rewrite(sink, sym + load_store_indexing)
sink = graph_rewrite(sink, sym + reducer)
return graph_rewrite(sink, sym + finalize)
def float4_rewrite(sink):
sink = graph_rewrite(sink, sym + migrate_indexing)
return graph_rewrite(sink, sym + expander + float4_folding)
sink = graph_rewrite(sink, sym + expander + float4_folding)
return graph_rewrite(sink, sym + finalize)
class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
@@ -617,11 +618,11 @@ class TestLoadStoreFolder(unittest.TestCase):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)).sink()
sink = full_graph_rewrite(sink, Renderer())
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
self.assertEqual(single_load.src[1].op, UOps.CONST)
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
@@ -636,7 +637,7 @@ class TestLoadStoreFolder(unittest.TestCase):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
sink = full_graph_rewrite(sink, Renderer())
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
def test_simple_store_fold_gate(self):
@@ -644,7 +645,7 @@ class TestLoadStoreFolder(unittest.TestCase):
gate = UOp.variable("g1", False, True, dtypes.bool)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
sink = full_graph_rewrite(sink, Renderer())
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
assert len(one_store.src) == 3
@@ -656,7 +657,8 @@ class TestLoadStoreFolder(unittest.TestCase):
gate2 = UOp.variable("g2", False, True, dtypes.bool)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
sink = full_graph_rewrite(sink, Renderer())
sink = float4_rewrite(sink)
print(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
class TestIFUOps(unittest.TestCase):