UOp.const(x.dtype, y) -> x.const(y) [run_process_replay] (#5642)

This commit is contained in:
chenyu
2024-07-22 17:09:40 -04:00
committed by GitHub
parent 97b116bb1d
commit 24505199fb
4 changed files with 20 additions and 21 deletions

View File

@@ -10,7 +10,7 @@ simple_pm = PatternMatcher([
(UOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(UOp.cvar('x') + UOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(UOp.cvar('x') * UOp.cvar('y') * UOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x + UOp.const(x.dtype, c1.arg+c2.arg)),
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)),
])
class TestGraphRewrite(unittest.TestCase):

View File

@@ -20,7 +20,7 @@ def image_contract_load(buf, idx, idy, id4, ls_allow_any_len):
else: extra = ls_allow_any_len.src[2:] # NOTE: image load shouldn't have barrier and this shouldn't matter
vec_load = UOp(UOps.LOAD, ls_allow_any_len.dtype.vec(4), (buf, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy))) + extra)
return functools.reduce(lambda ret, i: UOp.alu(TernaryOps.WHERE, id4.ne(i), ret, UOp(UOps.GEP, ls_allow_any_len.dtype, (vec_load,), i)), range(4),
UOp.const(ls_allow_any_len.dtype, float('nan')))
ls_allow_any_len.const(float('nan')))
def image_contract_store(buf, ex, idx, idy, ls_allow_any_len, var):
new_var = UOp(UOps.CONTRACT, var.dtype.vec(4), (var,), (ex.arg[0][0],))
@@ -190,16 +190,16 @@ constant_folder = PatternMatcher([
(UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s'))), lambda c,s: -s if -(s.src[1].arg-1) >= c.arg else None),
(UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s')+UOp.cvar('c2'))), lambda c,s,c2: -(s+c2) if -(s.src[1].arg-1+c2.arg) >= c.arg else None),
# const rules
(UOp(UOps.GEP, src=(UOp.cvar("c"),)).name("root"), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
(UOp(UOps.GEP, src=(UOp.cvar("c"),)).name("root"), lambda root, c: root.const(c.arg)),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
(UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
(UOp(UOps.PHI, src=(UOp(UOps.DEFINE_ACC).name("acc"), UOp.var("acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
(UOp(UOps.PHI, src=(UOp(UOps.DEFINE_ACC, src=(UOp.cvar(),)), UOp.var("x"))), lambda x: x),
(UOp(UOps.PHI, src=(UOp.cvar(), UOp.var("x"))), lambda x: x),
# a DEFINE_ACC without inputs is a const; a GEP on a const is the const
(UOp(UOps.DEFINE_ACC, src=(UOp.cvar(),)).name("root"), lambda root: UOp.cast(root.src[0], root.dtype)),
(UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: UOp.const(root.dtype, x.arg)),
(UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: root.const(x.arg)),
# max -2147483648
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
# bool < False is always false, True < bool is always false
@@ -209,39 +209,39 @@ constant_folder = PatternMatcher([
(UOp.var().where(UOp.var("val"), UOp.var("val")), lambda val: val),
(UOp.cvar('gate').where(UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
# ** self folding **
(-(-UOp.var('x')), lambda x: x), # -(-x) -> x
(UOp.var('x') + 0, lambda x: x), # x+0 -> x
(UOp.var('x') * 1, lambda x: x), # x*1 -> x
(UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
(UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
(UOp.var('x') // UOp.var('x'), lambda x: x.const(1)), # x//x -> 1
(UOp.var('x') // 1, lambda x: x), # x//1 -> x
(UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
(UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
(UOp.var('x') / UOp.var('x'), lambda x: x.const(1)), # x/x -> 1
(UOp.var('x') / UOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
(UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
# ** zero folding **
#x*0 -> 0 or 0*x -> 0
#if x is nan or inf it should render the nan value.
# NOTE: this can be wrong for loaded NaN
(UOp.var('x') * 0, lambda x: UOp.const(x.dtype, float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
(UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
(UOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
(UOp.var('x') - UOp.var('x'), lambda x: x.const(0)), # x-x -> 0
# ** load/store folding **
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
# ** two stage add/sub folding **
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+x.const(exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
# *** rules from symbolic ***
# mod divides
((UOp.cvar('c')*UOp.var('x')) % UOp.cvar('c'), lambda x,c: x.const(0)),
(((UOp.cvar('c')*UOp.var('x'))+UOp.var('x2')) % UOp.cvar('c'), lambda x,c,x2: x2%c),
# two stage mul, (x*c1)*c2 = x*(c1*c2)
((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*x.const(exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
# -(x+y) -> -x + -y
#(-(UOp.var("x") + UOp.var("y")), lambda x,y: (-x)+(-y)),
# x%1 -> 0
(UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
(UOp.var("x") % UOp.const(None, 1), lambda x: x.const(0)),
# (x*c0)+(x*c1) -> x*(c0+c1)
(UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
# (x*c0)+(y*c0) -> (x+y)*c0
@@ -252,12 +252,11 @@ constant_folder = PatternMatcher([
# (x*x2)/x2 -> x
((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
# (x//c0)//c1 -> x//(c0*c1)
((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//x.const(exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
# (x/x2)/x3 -> x/(x2*x3)
((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
# c0 + x < c1 -> x < c1 - c0
((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
# (x+x*c0)-> x*(c0+1)
(UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*(c0.arg+1)),
# x!=0 -> (bool)x
@@ -273,7 +272,7 @@ constant_folder = PatternMatcher([
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(2))).name("root"),
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
# NEG/CMPLT -> CMPLT
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(c.const(-c.arg), x)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UOp(UOps.CAST).name("root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
(UOp(UOps.VECTORIZE).name("root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),

View File

@@ -45,7 +45,7 @@ class UOp:
# Less-than between two UOps is decided by comparing their cmp_tuple attributes.
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
# Delegates to pretty_print, passing a formatter that renders a node's op, dtype and arg.
# NOTE(review): the "%s" inside src=(...) is presumably filled in by pretty_print with the rendered sources -- confirm.
def __repr__(self): return pretty_print(self, lambda x: f"UOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
# *** uop syntactic sugar
# Coerce x to a UOp: a value that is not already a UOp is wrapped as a constant, a UOp is returned unchanged.
# NOTE(review): two consecutive ufix definitions appear here; this looks like the pre-change form
# (UOp.const(self.dtype, x)) followed by the post-change form (self.const(x)) -- confirm which one the file keeps.
def ufix(self, x): return UOp.const(self.dtype, x) if not isinstance(x, UOp) else x
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
# Build a UOps.CAST node with the requested dtype and self as its single source.
def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
# Build a UOps.BITCAST node with the requested dtype and self as its single source.
def bitcast(self, dtype=None): return UOp(UOps.BITCAST, dtype, (self,))
# Wrap self in a UOps.VAR node whose arg is the given name.
# NOTE(review): presumably this is the name a pattern match binds to in the rewrite lambdas -- confirm.
def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name)

View File

@@ -262,7 +262,7 @@ ptx_matcher = PatternMatcher([
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
lambda root, alu, const: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
const.const(root.src[0].dtype.itemsize)*const)+root.src[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.CONST, name="const"))),
lambda root, const: UOp(root.op, root.dtype,