Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-07 03:00:26 -04:00)
simpler pattern matcher rules [run_process_replay] (#5620)
This commit is contained in:
@@ -140,8 +140,8 @@ constant_folder = PatternMatcher([
|
||||
lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)),
|
||||
# VECTORIZE/GEP
|
||||
(UOp(UOps.GEP, src=(UOp(UOps.VECTORIZE).name("cast"),)).name("gep"), lambda gep, cast: cast.src[gep.arg]),
|
||||
*[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j)
|
||||
for j in range(i))), lambda x: x) for i in [2, 4, 8]],
|
||||
*[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j) for j in range(i))), lambda x: x) \
|
||||
for i in [2, 4, 8]],
|
||||
# tensor core with a 0 input is acc
|
||||
(UOp(UOps.WMMA, src=(UOp.const(None, 0.0), UOp.var(), UOp.var('acc'))), lambda acc: acc),
|
||||
(UOp(UOps.WMMA, src=(UOp.var(), UOp.const(None, 0.0), UOp.var('acc'))), lambda acc: acc),
|
||||
@@ -155,13 +155,10 @@ constant_folder = PatternMatcher([
|
||||
# threefry
|
||||
(UOp(UOps.ALU, dtype=dtypes.uint64, src=(UOp.var("x"), UOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
|
||||
# arange loop folding (early)
|
||||
(UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(BinaryOps.MUL,
|
||||
UOp.cvar("mval"), UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
|
||||
UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None,0)), loop_collapse),
|
||||
(UOp.where(UOp.alu(BinaryOps.CMPLT, UOp.alu(BinaryOps.ADD, UOp.var("idx"), UOp.alu(UnaryOps.NEG,
|
||||
UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng"))),
|
||||
UOp.cvar("compval")), UOp.cvar("multconst"), UOp.const(None, 0)),
|
||||
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
|
||||
((UOp.var("idx") + UOp.cvar("mval") * UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng")).lt(UOp.cvar("compval")).where(
|
||||
UOp.cvar("multconst"), UOp.const(None, 0)), loop_collapse),
|
||||
((UOp.var("idx") - UOp(UOps.RANGE, src=(UOp.var("loop_start"), UOp.var("loop_end"))).name("rng")).lt(UOp.cvar("compval")).where(
|
||||
UOp.cvar("multconst"), UOp.const(None, 0)), lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
|
||||
# sum collapse to mul (with possible GEP)
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
||||
@@ -260,13 +257,13 @@ constant_folder = PatternMatcher([
|
||||
((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
|
||||
lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
|
||||
# (x+x*c0)-> x*(c0+1)
|
||||
(UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
|
||||
(UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*(c0.arg+1)),
|
||||
# x!=0 -> (bool)x
|
||||
(UOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
|
||||
# bool != 1 -> not bool
|
||||
(UOp.var("x", dtype=dtypes.bool).ne(1), lambda x: -x),
|
||||
# TODO: the inverse of this rule (with alt and load flipped) can also be done once we fix double ops
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("gate").where(UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
|
||||
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
|
||||
# VECTORIZE-PHI-GEP -> PHI-VECTORIZE
|
||||
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
|
||||
|
||||
Reference in New Issue
Block a user