diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 8ce2eb0869..d22c61e8b5 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -961,6 +961,10 @@ def simplify_valid(valid:UOp) -> Optional[UOp]: if ret[-1] is not stmt: something_changed = True return functools.reduce(operator.and_, ret) if something_changed else None +def max_var_const(x:UOp, c1:UOp, c2:UOp): + if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2 + if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1 + symbolic = PatternMatcher([ # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y), @@ -987,6 +991,10 @@ symbolic = PatternMatcher([ (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1) (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2 ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3) + # ** max fixups ** + (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c + ((UPat.var("x")+UPat.var("z")).max(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.max(y) + z), + ((UPat.var("x")*UPat.cvar("c1")).max(UPat.var("x")*UPat.cvar("c2")), max_var_const), # ** zero folding ** (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False (UPat.var("x", dtype=dtypes.ints) != UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints) @@ -1055,8 +1063,8 @@ symbolic_flat = symbolic+PatternMatcher([ _substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))]) # for debug -syms = { BinaryOps.ADD: "+", BinaryOps.MUL: "*", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>", - BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"} +syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>", + BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"} renderer = PatternMatcher([ (UPat((UOps.DEFINE_VAR, UOps.SPECIAL), name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])), (UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),