diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 7ae53bf259..62f29dfffd 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -93,6 +93,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([ (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0 (UPat.var("x") ^ UPat.var("x"), lambda x: x.const_like(0)), # x^x -> 0 + (UPat.var("x") & 0, lambda x: x.const_like(0)), # x&0 -> 0 (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) # ** constant folding ** @@ -231,8 +232,9 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ # if its a plus we add the associative variation too ((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \ lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None), - # ALU/variable min==max -> CONST (slow!) - (UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), + # ALU/variable min==max -> CONST + (UPat({Ops.CMPLT, Ops.CMPNE, Ops.IDIV, Ops.MOD, Ops.DEFINE_VAR, Ops.SPECIAL}, name="x"), + lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), (UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # max folding (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),