new UOp style patterns [run_process_replay] (#5444)

* express permute srcs in uop

* loop folding / sum collapse pats -> uop style

* UNMUL, const, phi on DEFINE_ACC pats -> uop style

* fix: cvar not const

* DEFINE_ACC w/o inputs, VECTORIZE-PHI-GEP pats -> uop style

* fix VECTORIZE-PHI-GEP pat

* contractor, reducer, float4 pats -> uop style

* arange folding .where

* one more

* revert permute expression in UOp
This commit is contained in:
Carson Powers
2024-07-13 20:21:08 -04:00
committed by GitHub
parent 942c58be90
commit ef578b4de8

View File

@@ -212,14 +212,14 @@ def cast_reduce(cst):
# Rewrite table for CONTRACT handling: a PatternMatcher built from (pattern, function) tuples.
contractor = PatternMatcher([
# contracts
# NOTE(review): the next two entries both target a UOps.CONTRACT uop bound to the name "root" and
# both dispatch to replace_contract. They appear to be the old (UPat) and new (UOp-style) sides of
# the same changed line in this diff -- confirm that only the UOp-style entry exists in the file.
(UPat(UOps.CONTRACT, name="root"), replace_contract),
(UOp(UOps.CONTRACT).name("root"), replace_contract),
# VECTORIZE after REDUCEs -> one REDUCE (breaks TestConv.test_two_binops_no_rerun)
# This entry is still written with UPat: src is a single UPat(UOps.REDUCE) rather than a tuple.
(UPat(UOps.VECTORIZE, name="cst", src=UPat(UOps.REDUCE)), cast_reduce),
])
# Rewrite table for REDUCE / WMMA / image LOAD-STORE handling: a PatternMatcher built from
# (pattern, function) tuples.
reducer = PatternMatcher([
# NOTE(review): the two UPat entries below and the two UOp-style entries after them name the same
# ops ("root" for REDUCE, "wmma" for WMMA) and the same functions. They appear to be the old and
# new sides of the changed lines in this diff -- confirm that only the UOp-style pair is live.
(UPat(UOps.REDUCE, name="root"), replace_reduce),
(UPat(UOps.WMMA, name="wmma"), expand_wmma),
(UOp(UOps.REDUCE).name("root"), replace_reduce),
(UOp(UOps.WMMA).name("wmma"), expand_wmma),
# image indexing. TODO: why can't this just go after the float stuff?
# This entry is still written with UPat: it is given a set of ops ({LOAD, STORE}), bound as "ls".
(UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
])
@@ -279,7 +279,7 @@ float4_folding = PatternMatcher([
(UOp(UOps.STORE, src=(UOp.var("buf"),
UOp(UOps.EXPAND).name("ex"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
# no ALU on float4 (float4 constructor doesn't work in METAL/GPU)
(UPat(UOps.ALU, name="alu"), no_float4_alu),
(UOp(UOps.ALU).name("alu"), no_float4_alu),
])
# ***** transcendental *****
@@ -352,15 +352,13 @@ constant_folder = PatternMatcher([
# threefry
(UOp(UOps.ALU, dtype=dtypes.uint64, src=(UOp.var("x"), UOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
# arange loop folding (early)
(UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, src=[UPat(UOps.CONST, name="mval"),
UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
(UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, UnaryOps.NEG, src=[
UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))),
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
(UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(BinaryOps.MUL,
UOp.cvar("mval"), UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None,0)), loop_collapse),
(UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(UnaryOps.NEG,
UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None, 0)),
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
# sum collapse to mul (with possible GEP)
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
@@ -387,16 +385,16 @@ constant_folder = PatternMatcher([
(UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s'))), lambda c,s: -s if -(s.src[1].arg-1) >= c.arg else None),
(UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s')+UOp.cvar('c2'))), lambda c,s,c2: -(s+c2) if -(s.src[1].arg-1+c2.arg) >= c.arg else None),
# const rules
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
(UOp(UOps.GEP, src=(UOp.cvar("c"),)).name("root"), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
(UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
(UOp(UOps.PHI, src=(UOp(UOps.DEFINE_ACC).name("acc"), UOp.var("acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
(UOp(UOps.PHI, src=(UOp(UOps.DEFINE_ACC, src=(UOp.cvar(),)), UOp.var("x"))), lambda x: x),
(UOp(UOps.PHI, src=(UOp.cvar(), UOp.var("x"))), lambda x: x),
# a DEFINE_ACC without inputs is a const + GEP on a const is the const
(UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
(UOp(UOps.DEFINE_ACC, src=(UOp.cvar(),)).name("root"), lambda root: UOp.cast(root.src[0], root.dtype)),
(UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: UOp.const(root.dtype, x.arg)),
# max -2147483648
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
# bool < False is always false, True < bool is always false
@@ -466,17 +464,15 @@ constant_folder = PatternMatcher([
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
# VECTORIZE-PHI-GEP -> PHI-VECTORIZE
(UPat(UOps.VECTORIZE, name="root", src=tuple(
UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
(UPat(UOps.VECTORIZE, name="root", src=tuple(
UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(2))).name("root"),
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
# NEG/CMPLT -> CMPLT
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
(UPat(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
(UOp(UOps.CAST).name("root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
(UOp(UOps.VECTORIZE).name("root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
# fold gated LOAD/STORE
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")),
@@ -486,7 +482,7 @@ constant_folder = PatternMatcher([
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.bool, True)), UOp.store),
(UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
# remove NOOPs from SINK
(UPat(UOps.SINK, name="root"),
(UOp(UOps.SINK).name("root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
])