mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
This reverts commit 996152d2de.
This commit is contained in:
@@ -4,9 +4,8 @@ from tinygrad import dtypes, Device
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
||||
from tinygrad.ops import UPat, PatternMatcher
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, load_store_indexing, sym, float4_folding, migrate_indexing
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding, finalize, migrate_indexing
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
|
||||
@@ -446,10 +445,12 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def expander_rewrite(sink):
|
||||
sink = graph_rewrite(sink, sym + expander)
|
||||
return graph_rewrite(sink, sym + load_store_indexing)
|
||||
sink = graph_rewrite(sink, sym + reducer)
|
||||
return graph_rewrite(sink, sym + finalize)
|
||||
def float4_rewrite(sink):
|
||||
sink = graph_rewrite(sink, sym + migrate_indexing)
|
||||
return graph_rewrite(sink, sym + expander + float4_folding)
|
||||
sink = graph_rewrite(sink, sym + expander + float4_folding)
|
||||
return graph_rewrite(sink, sym + finalize)
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
def test_expand_add_broadcast(self):
|
||||
@@ -617,11 +618,11 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)).sink()
|
||||
sink = full_graph_rewrite(sink, Renderer())
|
||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
|
||||
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
|
||||
self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
|
||||
self.assertEqual(single_load.src[1].op, UOps.CONST)
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
@@ -636,7 +637,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||
sink = full_graph_rewrite(sink, Renderer())
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
|
||||
|
||||
def test_simple_store_fold_gate(self):
|
||||
@@ -644,7 +645,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||
sink = full_graph_rewrite(sink, Renderer())
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
|
||||
one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
|
||||
assert len(one_store.src) == 3
|
||||
@@ -656,7 +657,8 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
gate2 = UOp.variable("g2", False, True, dtypes.bool)
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||
sink = full_graph_rewrite(sink, Renderer())
|
||||
sink = float4_rewrite(sink)
|
||||
print(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
|
||||
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user