From d609206a4ab5af3b3b6795ea32fb2ce0367ed8fc Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 1 Aug 2024 11:32:08 -0400 Subject: [PATCH] move UOp patterns around [run_process_replay] (#5857) group lt / div / mod together and minor cleanups --- tinygrad/codegen/uopgraph.py | 55 +++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 59b73f6da0..c782685eb5 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -204,51 +204,56 @@ constant_folder = PatternMatcher([ (NOp.var('x') / NOp.var('x'), lambda x: x.const(1)), # x/x -> 1 (NOp.var('x') / NOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c) # ** zero folding ** - #x*0 -> 0 or 0*x -> 0 - #if x is nan or inf it should render the nan value. + # x*0 -> 0 or 0*x -> 0 + # if x is nan or inf it should render the nan value. # NOTE: this can be wrong for loaded NaN (NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), - (NOp.var('x') - NOp.var('x'), lambda x: x.const(0)), # x-x -> 0 + # x-x -> 0 + (NOp.var('x') - NOp.var('x'), lambda x: x.const(0)), (UPat(op=UOps.ALU, name='x'), lambda x: x.const(x.vmin.arg) if x.op is not UOps.CONST and x.vmin.arg == x.vmax.arg else None), - # c0*x 0 and c1.arg > 0 else None), - # mul -> add -> lt - (((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')), - lambda x,x2,c0,c1: x.lt(c1.arg//c0.arg) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax.arg and x2.vmin.arg >= 0 else None), # ** load/store folding ** (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.load(NOp.var("buf"), NOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)), # ** two stage add/mul folding ** ((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x+x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))), ((NOp.var("x") * NOp.cvar("c1")) * NOp.cvar("c2"), lambda x,c1,c2: x*x.const(exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))), # *** rules from symbolic *** + # ** lt ** + # c0*x 0 and c1.arg > 0 else None), + # mul add lt + (((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')), + lambda x,x2,c0,c1: x.lt(c1.arg//c0.arg) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax.arg and x2.vmin.arg >= 0 else None), + # neg lt -> lt + (NOp.lt(-NOp.var('x'), NOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(c.const(-c.arg), x)), + # ** div ** # div folding (NOp.var('x') // NOp.cvar('c'), lambda x,c: x.const(x.vmin.arg//c.arg) if c.arg > 0 and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None), (NOp.var('x') // NOp.cvar('c'), lambda x,c: d if c.arg > 0 and (d:=x.divides(c.arg)) is not None else None), - # mod folding - (NOp.var('x') % NOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None), - # mod reduction - (NOp.var('x') % NOp.cvar('c'), lambda x,c: (x-(x.vmin.arg//c.arg)*c.arg)%c if 0 < c.arg <= x.vmin.arg else None), - # mul -> mod - ((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1:\ - x*(c0.arg%c1.arg)%c1 if c0.arg >= c1.arg > 0 else (x%(c1.arg//c0.arg))*c0 if c1.arg%c0.arg==0 else None), - # mul -> add -> mod - (((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')) % NOp.cvar('c1'), - lambda x,x2,c0,c1: x2%c1 if (r:=c0.arg%c1.arg) == 0 else (x*r+x2)%c1 if c0.arg >= c1.arg > 0 else None), - # mul -> add -> div + # mul div + ((NOp.var("x") * NOp.cvar("c0")) // NOp.cvar("c1"), + lambda x,c0,c1: x*(c0.arg//gcd)//(c1.arg//gcd) if c1.arg!=0 and (gcd:=math.gcd(c0.arg,c1.arg))> 1 else None), + # mul add div (((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')) // NOp.cvar('c1'), lambda x,x2,c0,c1:\ x*(c0.arg//g)//(c1.arg//g) if c0.arg > 0 and c1.arg > 0 and (g:=math.gcd(c0.arg,c1.arg)) > 1 and g > x2.vmax.arg and x2.vmin.arg >= 0 else None), + # ** mod ** + # mod folding and mod reduction + (NOp.var('x') % NOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else \ + (x-(x.vmin.arg//c.arg)*c.arg)%c if 0 < c.arg <= x.vmin.arg else None), + # mul mod + ((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1:\ + x*(c0.arg%c1.arg)%c1 if 0 < c1.arg <= c0.arg else (x%(c1.arg//c0.arg))*c0 if c1.arg%c0.arg == 0 else None), + # mul add mod + (((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')) % NOp.cvar('c1'), + lambda x,x2,c0,c1: x2%c1 if (r:=c0.arg%c1.arg) == 0 else (x*r+x2)%c1 if 0 < c1.arg <= c0.arg else None), # mod mod - ((NOp.var('x') % NOp.cvar('c0')) % NOp.cvar('c1'), lambda x,c0,c1: x % c0 if 0 < c0.arg < c1.arg else x % c1 if c0.arg % c1.arg == 0 else None), + ((NOp.var('x') % NOp.cvar('c0')) % NOp.cvar('c1'), lambda x,c0,c1: x % c1 if c0.arg % c1.arg == 0 else None), # -(x+y) -> -x + -y #(-(NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)), # (x*c0)+(x*c1) -> x*(c0+c1) (NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])), # (x*c0)+(y*c0) -> (x+y)*c0 #((NOp.var("x") * NOp.cvar("c0")) + (NOp.var("y") * NOp.cvar("c0")), lambda x,y,c0: c0*(x+y)), - # mul div - ((NOp.var("x") * NOp.cvar("c0")) // NOp.cvar("c1"), - lambda x,c0,c1: x*(c0.arg//gcd)//(c1.arg//gcd) if c1.arg!=0 and (gcd:=math.gcd(c0.arg,c1.arg))> 1 else None), # (x*x2)/x2 -> x ((NOp.var("x") * NOp.var("x2")) / NOp.var("x2"), lambda x,x2: x), # (x//c0)//c1 -> x//(c0*c1) @@ -271,8 +276,6 @@ constant_folder = PatternMatcher([ lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))), (NOp(UOps.VECTORIZE, src=tuple(NOp(UOps.PHI, src=(NOp(UOps.GEP, src=(NOp.var("val"),), arg=i), NOp.var(f"v{i}"))) for i in range(2)), name="root"), lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))), - # NEG/CMPLT -> CMPLT - (NOp.lt(-NOp.var('x'), NOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(c.const(-c.arg), x)), # cast NOOP (NOTE: it's str to deal with PtrDType) (NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None), (NOp(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),