diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 6b5e5ea224..88e61fba22 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -593,11 +593,11 @@ class TestLoadStoreFolder(unittest.TestCase): gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1") load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] sink = UOp(UOps.SINK, None, tuple(load)) - sink = float4_rewrite(sink) - assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1 + sink = full_graph_rewrite(sink) + assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 4 one_store = [x for x in sink.sparents if x.op is UOps.STORE][0] assert len(one_store.src) == 4 - assert str(one_store.src[3]) == str(UOp(UOps.IF, None, (gate,),)) # huh, why do i need str here? + assert_equiv_uops(one_store.src[3], UOp(UOps.IF, None, (gate, one_store.src[2]),)) def test_simple_store_dont_fold(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index d596596e6c..96b07266af 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -417,8 +417,6 @@ def create_gate(root:UOp) -> Optional[UOp]: if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER: # NOTE: gate could already be wrapped in an IF, so only take IF's src[0] in that case. Will join IFs in delete_redundant_gates return UOp(u.op, u.dtype, u.src[:-1] + (UOp(UOps.IF, None, (gate if gate.op is not UOps.IF else gate.src[0], u.src[-1])),), u.arg) - if u.op is UOps.STORE and len(u.src) == 4 and u.src[-1].op not in {UOps.IF, UOps.EXPAND}: - return UOp(u.op, u.dtype, u.src[:-1] + (UOp(UOps.IF, None, (gate,)),), u.arg) return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg) return None if len(root.src) == 3 or (ret:=_gate_srcs(root, root.src[3])) is root else ret @@ -452,8 +450,8 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]: return None def update_gates(root:UOp) -> Optional[UOp]: - if len(root.src) < 4 or len(root.src[3].src) >= 2: return None - return UOp(UOps.STORE, root.dtype, root.src[:3] + (UOp(UOps.IF, None, (root.src[3].src[0], root.src[2])),), root.arg) + if len(root.src) < 4 or root.src[3].op is UOps.IF: return None + return UOp(UOps.STORE, root.dtype, root.src[:3] + (UOp(UOps.IF, None, (root.src[3], root.src[2])),), root.arg) reducer = PatternMatcher([ (NOp(UOps.REDUCE, name="root"), do_reduce),