diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 213696b8c9..45361ad380 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -472,27 +472,25 @@ sym = PatternMatcher([ # *** rules from symbolic *** # ** lt ** # c0*x 0 and c1.arg > 0 else None), + ((UPat.cvar("c0", vec=False)*UPat.var("x", dtypes.ints)).lt(UPat.cvar("c1", vec=False)), + lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None), # c0*x 0 else None), + ((UPat.var("x", dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)), + lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None), # mul add lt (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)), lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None), # generic lt folding - (UPat.var("x").lt(UPat.cvar("c", vec=False)), - lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None), + (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None), # canonicalize a simplex with positive coefficients > 0 # not x < 1 -> X > 0 - (UPat.var("x").lt(1).ne(True), lambda x: newx.lt(1).ne(True) if dtypes.is_int(x.dtype) and (newx:=canonicalize_simplex(x)) is not None else None), + (UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), # ** div ** # # div folding - (UPat.var("x") // UPat.cvar("c", vec=False), lambda x,c: - newx if 0 < c.arg and not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None else None), + (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None), # ** mod ** # mod folding (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), @@ -506,7 +504,7 @@ sym = PatternMatcher([ (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y ((UPat.cvar("c0") + UPat.var("x")).lt(UPat.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0 # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue - ((UPat.var("x") + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None), + ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), # x!=0 -> (bool)x (UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))), # bitwise noops