mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
clean propagate_invalid more [pr] (#13347)
This commit is contained in:
@@ -24,14 +24,15 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
|
||||
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
|
||||
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
|
||||
|
||||
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
|
||||
propagate_invalid = PatternMatcher([
|
||||
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
|
||||
# propagate invalid, push it past children
|
||||
(invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype) if cast.dtype is not dtypes.index else None),
|
||||
(invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype)),
|
||||
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i))
|
||||
for op in GroupOp.Binary-GroupOp.Comparison),
|
||||
# TODO: when can this happen? and is it always safe to just drop invalid?
|
||||
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: x.alu(alu.op,y)) for op in GroupOp.Comparison),
|
||||
# invalid + y -> y same for other ops
|
||||
# invalid + y -> invalid same for other ops
|
||||
*((invalid_pat.alu(op, UPat(dtype=dtypes.index)).named("alu"), lambda alu,i: i) for op in GroupOp.Binary-GroupOp.Comparison),
|
||||
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
|
||||
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
|
||||
|
||||
Reference in New Issue
Block a user