some rules to simplify max (#7258)

This commit is contained in:
George Hotz
2024-10-24 15:27:21 +07:00
committed by GitHub
parent a7be9dfd71
commit b56fab54ea

View File

@@ -961,6 +961,10 @@ def simplify_valid(valid:UOp) -> Optional[UOp]:
if ret[-1] is not stmt: something_changed = True
return functools.reduce(operator.and_, ret) if something_changed else None
def max_var_const(x:UOp, c1:UOp, c2:UOp):
if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
symbolic = PatternMatcher([
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
@@ -987,6 +991,10 @@ symbolic = PatternMatcher([
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
# ** max fixups **
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
((UPat.var("x")+UPat.var("z")).max(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.max(y) + z),
((UPat.var("x")*UPat.cvar("c1")).max(UPat.var("x")*UPat.cvar("c2")), max_var_const),
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
@@ -1055,8 +1063,8 @@ symbolic_flat = symbolic+PatternMatcher([
_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
# for debug
syms = { BinaryOps.ADD: "+", BinaryOps.MUL: "*", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
renderer = PatternMatcher([
(UPat((UOps.DEFINE_VAR, UOps.SPECIAL), name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),