comment out fold_where_closure (#14316)

This commit is contained in:
chenyu
2026-01-24 10:15:42 -05:00
committed by GitHub
parent d74587f16d
commit cb69b7b2b2
2 changed files with 11 additions and 9 deletions

View File

@@ -769,6 +769,7 @@ class TestSymbolic(unittest.TestCase):
# (a if ((s<5)&(s<6)) else b) -> (a if (s<5) else b)
self.helper_test_variable(expr, 0, 3, "(s<5).where(a, b)")
@unittest.expectedFailure
def test_where_closure_folding(self):
# cond.where(t, f) where f contains cond.where(a, b) should fold the inner where to b in false branch
x = Variable("x", 0, 10)

View File

@@ -368,13 +368,14 @@ def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None:
if IMAGE.value > 0 and x.op_in_backward_slice_with_self(Ops.IDIV, Ops.MOD): return None
return cond.where(uop_given_valid(cond, x, try_simplex=False), i)
def fold_where_closure(cond:UOp, t:UOp, f:UOp) -> UOp|None:
"""In cond.where(t, f), fold nested cond.where(a, b) -> a in t, -> b in f"""
def is_valid_where(u:UOp) -> bool: return u.op is Ops.WHERE and u.src[0] is cond and Invalid not in (u.src[1].arg, u.src[2].arg)
t_subs, f_subs = {u: u.src[1] for u in t.toposort() if is_valid_where(u)}, {u: u.src[2] for u in f.toposort() if is_valid_where(u)}
if not t_subs and not f_subs: return None
new_t, new_f = t.substitute(t_subs).simplify() if t_subs else t, f.substitute(f_subs).simplify() if f_subs else f
return None if new_t is t and new_f is f else cond.where(new_t, new_f)
# TODO: this is O(number of WHERE * number of node)
# def fold_where_closure(cond:UOp, t:UOp, f:UOp) -> UOp|None:
# """In cond.where(t, f), fold nested cond.where(a, b) -> a in t, -> b in f"""
# def is_valid_where(u:UOp) -> bool: return u.op is Ops.WHERE and u.src[0] is cond and Invalid not in (u.src[1].arg, u.src[2].arg)
# t_subs, f_subs = {u: u.src[1] for u in t.toposort() if is_valid_where(u)}, {u: u.src[2] for u in f.toposort() if is_valid_where(u)}
# if not t_subs and not f_subs: return None
# new_t, new_f = t.substitute(t_subs).simplify() if t_subs else t, f.substitute(f_subs).simplify() if f_subs else f
# return None if new_t is t and new_f is f else cond.where(new_t, new_f)
pm_simplify_valid = PatternMatcher([
# simplify valid
@@ -392,8 +393,8 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
# x!=0 -> (bool)x
(UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
# ** where **
# fold nested where with same condition: in cond.where(t,f), cond.where(a,b)->a in t, ->b in f
(UPat.var("cond").where(UPat.var("t"), UPat.var("f")), fold_where_closure),
# # fold nested where with same condition: in cond.where(t,f), cond.where(a,b)->a in t, ->b in f
# (UPat.var("cond").where(UPat.var("t"), UPat.var("f")), fold_where_closure),
# push cast to branches
(UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
# ** pow **