clean propagate_invalid more [pr] (#13347)

This commit is contained in:
chenyu
2025-11-19 09:47:50 -05:00
committed by GitHub
parent 0c9fbf87e1
commit 79055ddb8b

View File

@@ -24,14 +24,15 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
propagate_invalid = PatternMatcher([
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
# propagate invalid, push it past children
(invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype) if cast.dtype is not dtypes.index else None),
(invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype)),
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i))
for op in GroupOp.Binary-GroupOp.Comparison),
# TODO: when can this happen? and is it always safe to just drop invalid?
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: x.alu(alu.op,y)) for op in GroupOp.Comparison),
# invalid + y -> y same for other ops
# invalid + y -> invalid same for other ops
*((invalid_pat.alu(op, UPat(dtype=dtypes.index)).named("alu"), lambda alu,i: i) for op in GroupOp.Binary-GroupOp.Comparison),
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),