mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
new UOp style patterns [run_process_replay] (#5444)
* express permute srcs in uop
* loop folding / sum collapse pats -> uop style
* UNMUL, const, phi on DEFINE_ACC pats -> uop style
* fix: cvar not const
* DEFINE_ACC w/o inputs, VECTORIZE-PHI-GEP pats -> uop style
* fix VECTORIZE-PHI-GEP pat
* contractor, reducer, float4 pats -> uop style
* arange folding .where
* one more
* revert permute expression in UOp
This commit is contained in:
@@ -212,14 +212,14 @@ def cast_reduce(cst):
|
||||
|
||||
contractor = PatternMatcher([
|
||||
# contracts
|
||||
(UPat(UOps.CONTRACT, name="root"), replace_contract),
|
||||
(UOp(UOps.CONTRACT).name("root"), replace_contract),
|
||||
# VECTORIZE after REDUCEs -> one REDUCE (breaks TestConv.test_two_binops_no_rerun)
|
||||
(UPat(UOps.VECTORIZE, name="cst", src=UPat(UOps.REDUCE)), cast_reduce),
|
||||
])
|
||||
|
||||
reducer = PatternMatcher([
|
||||
(UPat(UOps.REDUCE, name="root"), replace_reduce),
|
||||
(UPat(UOps.WMMA, name="wmma"), expand_wmma),
|
||||
(UOp(UOps.REDUCE).name("root"), replace_reduce),
|
||||
(UOp(UOps.WMMA).name("wmma"), expand_wmma),
|
||||
# image indexing. TODO: why can't this just go after the float stuff?
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
|
||||
])
|
||||
@@ -279,7 +279,7 @@ float4_folding = PatternMatcher([
|
||||
(UOp(UOps.STORE, src=(UOp.var("buf"),
|
||||
UOp(UOps.EXPAND).name("ex"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
|
||||
# no ALU on float4 (float4 constructor doesn't work in METAL/GPU)
|
||||
(UPat(UOps.ALU, name="alu"), no_float4_alu),
|
||||
(UOp(UOps.ALU).name("alu"), no_float4_alu),
|
||||
])
|
||||
|
||||
# ***** transcendental *****
|
||||
@@ -352,15 +352,13 @@ constant_folder = PatternMatcher([
|
||||
# threefry
|
||||
(UOp(UOps.ALU, dtype=dtypes.uint64, src=(UOp.var("x"), UOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
|
||||
# arange loop folding (early)
|
||||
(UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, src=[UPat(UOps.CONST, name="mval"),
|
||||
UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
|
||||
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
|
||||
(UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, UnaryOps.NEG, src=[
|
||||
UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
|
||||
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))),
|
||||
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
|
||||
(UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(BinaryOps.MUL,
|
||||
UOp.cvar("mval"), UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
|
||||
UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None,0)), loop_collapse),
|
||||
(UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(UnaryOps.NEG,
|
||||
UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
|
||||
UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None, 0)),
|
||||
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
|
||||
# sum collapse to mul (with possible GEP)
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
||||
@@ -387,16 +385,16 @@ constant_folder = PatternMatcher([
|
||||
(UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s'))), lambda c,s: -s if -(s.src[1].arg-1) >= c.arg else None),
|
||||
(UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s')+UOp.cvar('c2'))), lambda c,s,c2: -(s+c2) if -(s.src[1].arg-1+c2.arg) >= c.arg else None),
|
||||
# const rules
|
||||
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
(UOp(UOps.GEP, src=(UOp.cvar("c"),)).name("root"), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
(UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
|
||||
(UOp(UOps.PHI, src=(UOp(UOps.DEFINE_ACC).name("acc"), UOp.var("acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
|
||||
(UOp(UOps.PHI, src=(UOp(UOps.DEFINE_ACC, src=(UOp.cvar(),)), UOp.var("x"))), lambda x: x),
|
||||
(UOp(UOps.PHI, src=(UOp.cvar(), UOp.var("x"))), lambda x: x),
|
||||
# a DEFINE_ACC without inputs is a const + GEP on a const is the const
|
||||
(UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
|
||||
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
|
||||
(UOp(UOps.DEFINE_ACC, src=(UOp.cvar(),)).name("root"), lambda root: UOp.cast(root.src[0], root.dtype)),
|
||||
(UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: UOp.const(root.dtype, x.arg)),
|
||||
# max -2147483648
|
||||
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
|
||||
# bool < False is always false, True < bool is always false
|
||||
@@ -466,17 +464,15 @@ constant_folder = PatternMatcher([
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
|
||||
UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
|
||||
# VECTORIZE-PHI-GEP -> PHI-VECTORIZE
|
||||
(UPat(UOps.VECTORIZE, name="root", src=tuple(
|
||||
UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
|
||||
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
|
||||
(UPat(UOps.VECTORIZE, name="root", src=tuple(
|
||||
UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
|
||||
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
|
||||
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
|
||||
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
|
||||
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(2))).name("root"),
|
||||
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
|
||||
# NEG/CMPLT -> CMPLT
|
||||
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
(UPat(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
(UOp(UOps.CAST).name("root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
(UOp(UOps.VECTORIZE).name("root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
# fold gated LOAD/STORE
|
||||
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
|
||||
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")),
|
||||
@@ -486,7 +482,7 @@ constant_folder = PatternMatcher([
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.bool, True)), UOp.store),
|
||||
(UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
|
||||
# remove NOOPs from SINK
|
||||
(UPat(UOps.SINK, name="root"),
|
||||
(UOp(UOps.SINK).name("root"),
|
||||
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
|
||||
])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user