mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
dearg consts [run_process_replay] (#6324)
This commit is contained in:
@@ -284,8 +284,8 @@ constant_folder = PatternMatcher([
|
||||
# ** load/store folding **
|
||||
(NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.load(NOp.var("buf"), NOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
|
||||
# ** two stage add/mul folding **
|
||||
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x+x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
|
||||
((NOp.var("x") * NOp.cvar("c1")) * NOp.cvar("c2"), lambda x,c1,c2: x*x.const(exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
|
||||
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x+(c1+c2)),
|
||||
((NOp.var("x") * NOp.cvar("c1")) * NOp.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
|
||||
# *** rules from symbolic ***
|
||||
# ** lt **
|
||||
# c0*x<c1 for positive int c0,c1
|
||||
@@ -293,7 +293,7 @@ constant_folder = PatternMatcher([
|
||||
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if dtypes.is_int(x.dtype) and c0.arg > 0 and c1.arg > 0 else None),
|
||||
# mul add lt
|
||||
(((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')),
|
||||
lambda x,x2,c0,c1: x.lt(c1.arg//c0.arg) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax.arg and x2.vmin.arg >= 0 else None),
|
||||
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax.arg and x2.vmin.arg >= 0 else None),
|
||||
# generic lt folding (use div)
|
||||
(NOp.var('x').lt(NOp.cvar('c')), lambda x,c: newx.src[0].lt(newx.src[1]) if 0 < c.arg and dtypes.is_int(x.dtype) and \
|
||||
not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV else None),
|
||||
@@ -305,7 +305,7 @@ constant_folder = PatternMatcher([
|
||||
# mod folding
|
||||
(NOp.var('x') % NOp.cvar('c'), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
# mul mod
|
||||
((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1.arg//c0.arg))*c0 if c1.arg%c0.arg == 0 else None),
|
||||
((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1//c0))*c0 if c1.arg%c0.arg == 0 else None),
|
||||
# (x%c)+(x//c)*c = x
|
||||
(NOp.var('x')%NOp.cvar('c')+(NOp.var('x')//NOp.cvar('c'))*NOp.cvar('c'), lambda x,c: x),
|
||||
# ** combine terms **
|
||||
@@ -313,11 +313,11 @@ constant_folder = PatternMatcher([
|
||||
(-(NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)),
|
||||
# (x+c0)*c1 -> x*c1+c0*c1. only for signed ints; floats have the inf*0=nan issue
|
||||
((NOp.var("x") + NOp.cvar("c0")) * NOp.cvar("c1"), lambda x,c0,c1:
|
||||
x*c1+c0.arg*c1.arg if dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
|
||||
x*c1+c0*c1 if dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
|
||||
# (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
|
||||
(NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*(c0+c1)),
|
||||
# (x+x*c)-> x*(c+1)
|
||||
(NOp.var("x") + NOp.var("x") * NOp.cvar("c"), lambda x,c: x*(c.arg+1)),
|
||||
(NOp.var("x") + NOp.var("x") * NOp.cvar("c"), lambda x,c: x*(c+1)),
|
||||
# (x+x)-> x*2
|
||||
(NOp.var("x") + NOp.var("x"), lambda x: x*2),
|
||||
# (x*c0)+(y*c0) -> (x+y)*c0
|
||||
@@ -325,11 +325,11 @@ constant_folder = PatternMatcher([
|
||||
# (x*x2)/x2 -> x
|
||||
((NOp.var("x") * NOp.var("x2")) / NOp.var("x2"), lambda x,x2: x),
|
||||
# (x//c0)//c1 -> x//(c0*c1)
|
||||
((NOp.var("x") // NOp.cvar("c0")) // NOp.cvar("c1"), lambda x,c0,c1: x//x.const(exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
|
||||
((NOp.var("x") // NOp.cvar("c0")) // NOp.cvar("c1"), lambda x,c0,c1: x//(c0*c1)),
|
||||
# (x/x2)/x3 -> x/(x2*x3)
|
||||
((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
|
||||
# c0 + x < c1 -> x < c1 - c0
|
||||
((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
|
||||
((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)),
|
||||
# x!=0 -> (bool)x
|
||||
(NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
|
||||
# TODO: can do the inverse of this (flip alt/load) when we fix double ops
|
||||
|
||||
Reference in New Issue
Block a user