simpler pattern matcher rules [run_process_replay] (#5620)

This commit is contained in:
chenyu
2024-07-21 04:05:01 -04:00
committed by GitHub
parent 0f67ef4674
commit a823759dc5

View File

@@ -140,8 +140,8 @@ constant_folder = PatternMatcher([
lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)),
# VECTORIZE/GEP
(UOp(UOps.GEP, src=(UOp(UOps.VECTORIZE).name("cast"),)).name("gep"), lambda gep, cast: cast.src[gep.arg]),
-*[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j)
-for j in range(i))), lambda x: x) for i in [2, 4, 8]],
+*[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j) for j in range(i))), lambda x: x) \
+for i in [2, 4, 8]],
# tensor core with a 0 input is acc
(UOp(UOps.WMMA, src=(UOp.const(None, 0.0), UOp.var(), UOp.var('acc'))), lambda acc: acc),
(UOp(UOps.WMMA, src=(UOp.var(), UOp.const(None, 0.0), UOp.var('acc'))), lambda acc: acc),
@@ -155,13 +155,10 @@ constant_folder = PatternMatcher([
# threefry
(UOp(UOps.ALU, dtype=dtypes.uint64, src=(UOp.var("x"), UOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
# arange loop folding (early)
-(UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(BinaryOps.MUL,
-UOp.cvar("mval"), UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
-UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None,0)), loop_collapse),
-(UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(UnaryOps.NEG,
-UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
-UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None, 0)),
-lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
+((UOp.var("idx") + UOp.cvar("mval") * UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng")).lt(UOp.cvar("compval")).where(
+UOp.cvar("multconst"), UOp.const(None, 0)), loop_collapse),
+((UOp.var("idx") - UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng")).lt(UOp.cvar("compval")).where(
+UOp.cvar("multconst"), UOp.const(None, 0)), lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
# sum collapse to mul (with possible GEP)
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
@@ -260,13 +257,13 @@ constant_folder = PatternMatcher([
((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
# (x+x*c0)-> x*(c0+1)
-(UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
+(UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*(c0.arg+1)),
# x!=0 -> (bool)x
(UOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
# bool != 1 -> not bool
(UOp.var("x", dtype=dtypes.bool).ne(1), lambda x: -x),
# TODO: can do the invert of this (flip alt/load) when we fix double ops
-(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
+(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("gate").where(UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
# VECTORIZE-PHI-GEP -> PHI-VECTORIZE
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),