From cb69b7b2b28efe6f7bb553770b4e270df789bf5e Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 24 Jan 2026 10:15:42 -0500 Subject: [PATCH] comment out fold_where_closure (#14316) --- test/unit/test_uop_symbolic.py | 1 + tinygrad/uop/symbolic.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index e911ef8909..4968c81b87 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -769,6 +769,7 @@ class TestSymbolic(unittest.TestCase): # (a if ((s<5)&(s<6)) else b) -> (a if (s<5) else b) self.helper_test_variable(expr, 0, 3, "(s<5).where(a, b)") + @unittest.expectedFailure def test_where_closure_folding(self): # cond.where(t, f) where f contains cond.where(a, b) should fold the inner where to b in false branch x = Variable("x", 0, 10) diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 16126724a1..136def2d86 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -368,13 +368,14 @@ def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None: if IMAGE.value > 0 and x.op_in_backward_slice_with_self(Ops.IDIV, Ops.MOD): return None return cond.where(uop_given_valid(cond, x, try_simplex=False), i) -def fold_where_closure(cond:UOp, t:UOp, f:UOp) -> UOp|None: - """In cond.where(t, f), fold nested cond.where(a, b) -> a in t, -> b in f""" - def is_valid_where(u:UOp) -> bool: return u.op is Ops.WHERE and u.src[0] is cond and Invalid not in (u.src[1].arg, u.src[2].arg) - t_subs, f_subs = {u: u.src[1] for u in t.toposort() if is_valid_where(u)}, {u: u.src[2] for u in f.toposort() if is_valid_where(u)} - if not t_subs and not f_subs: return None - new_t, new_f = t.substitute(t_subs).simplify() if t_subs else t, f.substitute(f_subs).simplify() if f_subs else f - return None if new_t is t and new_f is f else cond.where(new_t, new_f) +# TODO: this is O(number of WHERE * number of node) +# def fold_where_closure(cond:UOp, t:UOp, f:UOp) -> UOp|None: +# """In cond.where(t, f), fold nested cond.where(a, b) -> a in t, -> b in f""" +# def is_valid_where(u:UOp) -> bool: return u.op is Ops.WHERE and u.src[0] is cond and Invalid not in (u.src[1].arg, u.src[2].arg) +# t_subs, f_subs = {u: u.src[1] for u in t.toposort() if is_valid_where(u)}, {u: u.src[2] for u in f.toposort() if is_valid_where(u)} +# if not t_subs and not f_subs: return None +# new_t, new_f = t.substitute(t_subs).simplify() if t_subs else t, f.substitute(f_subs).simplify() if f_subs else f +# return None if new_t is t and new_f is f else cond.where(new_t, new_f) pm_simplify_valid = PatternMatcher([ # simplify valid @@ -392,8 +393,8 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([ # x!=0 -> (bool)x (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))), # ** where ** - # fold nested where with same condition: in cond.where(t,f), cond.where(a,b)->a in t, ->b in f - (UPat.var("cond").where(UPat.var("t"), UPat.var("f")), fold_where_closure), + # # fold nested where with same condition: in cond.where(t,f), cond.where(a,b)->a in t, ->b in f + # (UPat.var("cond").where(UPat.var("t"), UPat.var("f")), fold_where_closure), # push cast to branches (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))), # ** pow **