mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
some rules to simplify max (#7258)
This commit is contained in:
@@ -961,6 +961,10 @@ def simplify_valid(valid:UOp) -> Optional[UOp]:
|
||||
if ret[-1] is not stmt: something_changed = True
|
||||
return functools.reduce(operator.and_, ret) if something_changed else None
|
||||
|
||||
def max_var_const(x:UOp, c1:UOp, c2:UOp):
|
||||
if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
|
||||
if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
|
||||
|
||||
symbolic = PatternMatcher([
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
|
||||
@@ -987,6 +991,10 @@ symbolic = PatternMatcher([
|
||||
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
|
||||
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
|
||||
# ** max fixups **
|
||||
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
|
||||
((UPat.var("x")+UPat.var("z")).max(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.max(y) + z),
|
||||
((UPat.var("x")*UPat.cvar("c1")).max(UPat.var("x")*UPat.cvar("c2")), max_var_const),
|
||||
# ** zero folding **
|
||||
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
|
||||
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
|
||||
@@ -1055,8 +1063,8 @@ symbolic_flat = symbolic+PatternMatcher([
|
||||
_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
|
||||
# for debug
|
||||
syms = { BinaryOps.ADD: "+", BinaryOps.MUL: "*", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
|
||||
BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
|
||||
syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
|
||||
BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
|
||||
renderer = PatternMatcher([
|
||||
(UPat((UOps.DEFINE_VAR, UOps.SPECIAL), name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
|
||||
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
|
||||
|
||||
Reference in New Issue
Block a user