mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 23:38:58 -05:00
[run_process_replay] style: clean up UPat
This commit is contained in:
@@ -142,9 +142,8 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
|
||||
constant_folder = PatternMatcher([
|
||||
# arange loop folding (early)
|
||||
(UPat(UOps.ALU, TernaryOps.WHERE, vin=(UPat(UOps.ALU, BinaryOps.CMPLT, vin=(
|
||||
UPat(UOps.ALU, BinaryOps.ADD, vin=
|
||||
[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL,
|
||||
vin=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, vin=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL,
|
||||
vin=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, vin=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
|
||||
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
|
||||
# sum collapse to mul (with possible GEP)
|
||||
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="phi_input", vin=(UPat(UOps.RANGE, name="loop"),)),
|
||||
@@ -160,8 +159,7 @@ constant_folder = PatternMatcher([
|
||||
(UPat(UOps.CAST, name="root", vin=(UPat(UOps.UNMUL, name="unmul"),)),
|
||||
lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
|
||||
# max on special can go away (TODO: special should be variable, same thing applies)
|
||||
(UPat(UOps.ALU, BinaryOps.MAX, [UPat(UOps.CONST, name="c"), UPat(UOps.SPECIAL, name="s")]),
|
||||
lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
|
||||
(UPat(UOps.ALU, BinaryOps.MAX, [UPat(UOps.CONST, name="c"), UPat(UOps.SPECIAL, name="s")]), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
|
||||
# const rules
|
||||
(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
(UPat(UOps.CAST, name="root", vin=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
@@ -177,21 +175,17 @@ constant_folder = PatternMatcher([
|
||||
# -(-x) -> x
|
||||
(UPat(UOps.ALU, UnaryOps.NEG, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)))), lambda x: x),
|
||||
# x+-y -> x-y
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, (UPat(name="x"), UPat(UOps.ALU, UnaryOps.NEG, name="my"))),
|
||||
lambda x, my: x-my.vin[0]),
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, (UPat(name="x"), UPat(UOps.ALU, UnaryOps.NEG, name="my"))), lambda x, my: x-my.vin[0]),
|
||||
# -1*x -> -x
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, -1)]), lambda x: -x),
|
||||
# bool < False is always false, True < bool is always false
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(), UPat(UOps.CONST, False, name="x", dtype=dtypes.bool))), lambda x: x),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.CONST, True, name="x", dtype=dtypes.bool), UPat())),
|
||||
lambda x: UOp.const(dtypes.bool, False)),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.CONST, True, name="x", dtype=dtypes.bool), UPat())), lambda x: UOp.const(dtypes.bool, False)),
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(UPat(UOps.ALU, TernaryOps.WHERE, (UPat(), UPat(name="val"), UPat(name="val"))), lambda val: val),
|
||||
(UPat(UOps.ALU, TernaryOps.WHERE, (UPat(UOps.CONST, name="gate"), UPat(name="c0"), UPat(name="c1"))),
|
||||
lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
(UPat(UOps.ALU, TernaryOps.WHERE, (UPat(UOps.CONST, name="gate"), UPat(name="c0"), UPat(name="c1"))), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ** constant folding **
|
||||
(UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)),
|
||||
lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
|
||||
(UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
|
||||
# ** self folding **
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="x"), UPat(UOps.CONST, 0)]), lambda x: x), # x+0 -> x or 0+x -> x
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, 1)]), lambda x: x), # x*1 -> x or 1*x -> x
|
||||
@@ -205,15 +199,13 @@ constant_folder = PatternMatcher([
|
||||
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"),
|
||||
UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)),
|
||||
# ** two stage add/sub folding **
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.ADD,
|
||||
[UPat(name="x"), UPat(UOps.CONST, name="c1")]), UPat(UOps.CONST, name="c2")]),
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="x"), UPat(UOps.CONST, name="c1")]), UPat(UOps.CONST, name="c2")]),
|
||||
lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.SUB,
|
||||
(UPat(name="x"), UPat(UOps.CONST, name="c1"))), UPat(UOps.CONST, name="c2")]),
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.SUB, (UPat(name="x"), UPat(UOps.CONST, name="c1"))), UPat(UOps.CONST, name="c2")]),
|
||||
lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c2.arg, c1.arg]))),
|
||||
# TODO: can do the invert of this (flip alt/load) when we fix double ops
|
||||
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.ALU, TernaryOps.WHERE,
|
||||
(UPat(name="gate"), UPat(name="alt"), UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))))),
|
||||
(UPat(name="gate"), UPat(name="alt"), UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))))),
|
||||
lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
|
||||
# store float4/float2 directly (remove CAST/GEP)
|
||||
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.CAST, vin=
|
||||
@@ -223,15 +215,12 @@ constant_folder = PatternMatcher([
|
||||
tuple(UPat(UOps.GEP, i, vin=(UPat(name="val"),)) for i in range(2))))),
|
||||
lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
|
||||
# CAST-PHI-GEP -> PHI-CAST
|
||||
(UPat(UOps.CAST, name="root", vin=
|
||||
tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
|
||||
(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
|
||||
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
|
||||
(UPat(UOps.CAST, name="root", vin=
|
||||
tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
|
||||
(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
|
||||
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
|
||||
# NEG/CMPLT -> CMPLT
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)),
|
||||
UPat(UOps.CONST, name="c", dtype=dtypes.int))),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)), UPat(UOps.CONST, name="c", dtype=dtypes.int))),
|
||||
lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
|
||||
|
||||
@@ -235,27 +235,25 @@ ptx_matcher = PatternMatcher([
|
||||
(UPat(UOps.ALU, BinaryOps.DIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
|
||||
lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"),
|
||||
lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
|
||||
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
|
||||
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
|
||||
(UPat(UOps.ALU, BinaryOps.ADD,
|
||||
[UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
|
||||
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
|
||||
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
|
||||
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
|
||||
for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool,
|
||||
vin=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
|
||||
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
|
||||
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(),UPat())),
|
||||
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
|
||||
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
|
||||
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
|
||||
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
|
||||
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
|
||||
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
|
||||
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
|
||||
# ptr_ar (load/store)
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
|
||||
|
||||
Reference in New Issue
Block a user