mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
late gate creation for STORE [run_process_replay] (#6373)
This commit is contained in:
@@ -593,11 +593,11 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, None, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
|
||||
sink = full_graph_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 4
|
||||
one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
|
||||
assert len(one_store.src) == 4
|
||||
assert str(one_store.src[3]) == str(UOp(UOps.IF, None, (gate,),)) # huh, why do i need str here?
|
||||
assert_equiv_uops(one_store.src[3], UOp(UOps.IF, None, (gate, one_store.src[2]),))
|
||||
|
||||
def test_simple_store_dont_fold(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
|
||||
@@ -417,8 +417,6 @@ def create_gate(root:UOp) -> Optional[UOp]:
|
||||
if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
|
||||
# NOTE: gate could already be wrapped in an IF, so only take IF's src[0] in that case. Will join IFs in delete_redundant_gates
|
||||
return UOp(u.op, u.dtype, u.src[:-1] + (UOp(UOps.IF, None, (gate if gate.op is not UOps.IF else gate.src[0], u.src[-1])),), u.arg)
|
||||
if u.op is UOps.STORE and len(u.src) == 4 and u.src[-1].op not in {UOps.IF, UOps.EXPAND}:
|
||||
return UOp(u.op, u.dtype, u.src[:-1] + (UOp(UOps.IF, None, (gate,)),), u.arg)
|
||||
return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
|
||||
return None if len(root.src) == 3 or (ret:=_gate_srcs(root, root.src[3])) is root else ret
|
||||
|
||||
@@ -452,8 +450,8 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
|
||||
return None
|
||||
|
||||
def update_gates(root:UOp) -> Optional[UOp]:
|
||||
if len(root.src) < 4 or len(root.src[3].src) >= 2: return None
|
||||
return UOp(UOps.STORE, root.dtype, root.src[:3] + (UOp(UOps.IF, None, (root.src[3].src[0], root.src[2])),), root.arg)
|
||||
if len(root.src) < 4 or root.src[3].op is UOps.IF: return None
|
||||
return UOp(UOps.STORE, root.dtype, root.src[:3] + (UOp(UOps.IF, None, (root.src[3], root.src[2])),), root.arg)
|
||||
|
||||
reducer = PatternMatcher([
|
||||
(NOp(UOps.REDUCE, name="root"), do_reduce),
|
||||
|
||||
Reference in New Issue
Block a user