diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 71544958e0..389d68c492 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -430,9 +430,6 @@ sym = simple_pm+PatternMatcher([
     name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
   # GEP/CAST const rules
   (UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
-  # a conditional with the same results either way is a noop, also fold const conditionals
-  (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
-  (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
   # ** self folding **
   # cast NOOP (NOTE: it's str to deal with PtrDType)
   (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
@@ -451,12 +448,6 @@ sym = simple_pm+PatternMatcher([
   # ** mod **
   # mod folding
   (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
-  # ** combine terms (opinionated) **
-  (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
-  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
-  (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
-  # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
-  ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
   # x!=0 -> (bool)x
   (UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
   # TODO: can do the invert of this (flip alt/load) when we fix double ops
@@ -480,12 +471,6 @@ sym = simple_pm+PatternMatcher([
   (UPat(UOps.SINK, name="root"),
    lambda root: UOp(UOps.SINK, root.dtype, tuple(flatten(x.src if x.op in {UOps.SINK, UOps.EXPAND} else (x,) for x in root.src)), root.arg)
     if any(x.op in {UOps.SINK, UOps.EXPAND} for x in root.src) else None),
-  # ** move add consts to end (NOTE: this is still happening before constant folding) **
-  (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x+c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
-  (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
-  # ** move mul consts to end (NOTE: this is still happening before constant folding) **
-  (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x*c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
-  (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
 ])
 
 # *** uop expander ***
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 3bf8b6a414..9a0f1a0090 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -727,6 +727,12 @@ simple_pm = PatternMatcher([
   # ** combine terms **
   (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
   (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
+  # ** combine terms (opinionated) **
+  (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
+  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
+  (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
+  # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
+  ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
   # ** zero folding **
   (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
   (UPat.var("x", dtype=dtypes.ints) != UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
@@ -734,6 +740,9 @@ simple_pm = PatternMatcher([
   # if x is nan or inf it should render the nan value.
   # NOTE: this can be wrong for loaded NaN
   (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
+  # a conditional with the same results either way is a noop, also fold const conditionals
+  (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
+  (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
   # ** constant folding **
   (UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
    lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
@@ -763,4 +772,10 @@ simple_pm = PatternMatcher([
   # mul add lt
   (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
    lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
+  # ** move add consts to end (NOTE: this is still happening before constant folding) **
+  (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x+c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
+  (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
+  # ** move mul consts to end (NOTE: this is still happening before constant folding) **
+  (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x*c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
+  (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
 ])