diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index f1ea40f209..6a18b08157 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -142,9 +142,8 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst): constant_folder = PatternMatcher([ # arange loop folding (early) (UPat(UOps.ALU, TernaryOps.WHERE, vin=(UPat(UOps.ALU, BinaryOps.CMPLT, vin=( - UPat(UOps.ALU, BinaryOps.ADD, vin= - [UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, - vin=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, vin=(UPat(name="loop_start"), UPat(name="loop_end")))])]), + UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, + vin=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, vin=(UPat(name="loop_start"), UPat(name="loop_end")))])]), UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse), # sum collapse to mul (with possible GEP) (UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="phi_input", vin=(UPat(UOps.RANGE, name="loop"),)), @@ -160,8 +159,7 @@ constant_folder = PatternMatcher([ (UPat(UOps.CAST, name="root", vin=(UPat(UOps.UNMUL, name="unmul"),)), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))), # max on special can go away (TODO: special should be variable, same thing applies) - (UPat(UOps.ALU, BinaryOps.MAX, [UPat(UOps.CONST, name="c"), UPat(UOps.SPECIAL, name="s")]), - lambda c,s: c if (s.arg[2]-1) <= c.arg else None), + (UPat(UOps.ALU, BinaryOps.MAX, [UPat(UOps.CONST, name="c"), UPat(UOps.SPECIAL, name="s")]), lambda c,s: c if (s.arg[2]-1) <= c.arg else None), # const rules (UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)), (UPat(UOps.CAST, name="root", vin=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)), @@ -177,21 +175,17 @@ constant_folder = PatternMatcher([ # -(-x) -> x (UPat(UOps.ALU, UnaryOps.NEG, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)))), lambda x: x), # x+-y -> x-y - (UPat(UOps.ALU, BinaryOps.ADD, (UPat(name="x"), UPat(UOps.ALU, UnaryOps.NEG, name="my"))), - lambda x, my: x-my.vin[0]), + (UPat(UOps.ALU, BinaryOps.ADD, (UPat(name="x"), UPat(UOps.ALU, UnaryOps.NEG, name="my"))), lambda x, my: x-my.vin[0]), # -1*x -> -x (UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, -1)]), lambda x: -x), # bool < False is always false, True < bool is always false (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(), UPat(UOps.CONST, False, name="x", dtype=dtypes.bool))), lambda x: x), - (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.CONST, True, name="x", dtype=dtypes.bool), UPat())), - lambda x: UOp.const(dtypes.bool, False)), + (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.CONST, True, name="x", dtype=dtypes.bool), UPat())), lambda x: UOp.const(dtypes.bool, False)), # a conditional with the same results either way is a noop, also fold const conditionals (UPat(UOps.ALU, TernaryOps.WHERE, (UPat(), UPat(name="val"), UPat(name="val"))), lambda val: val), - (UPat(UOps.ALU, TernaryOps.WHERE, (UPat(UOps.CONST, name="gate"), UPat(name="c0"), UPat(name="c1"))), - lambda gate, c0, c1: c0 if gate.arg else c1), + (UPat(UOps.ALU, TernaryOps.WHERE, (UPat(UOps.CONST, name="gate"), UPat(name="c0"), UPat(name="c1"))), lambda gate, c0, c1: c0 if gate.arg else c1), # ** constant folding ** - (UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)), - lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))), + (UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))), # ** self folding ** (UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="x"), UPat(UOps.CONST, 0)]), lambda x: x), # x+0 -> x or 0+x -> x (UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, 1)]), lambda x: x), # x*1 -> x or 1*x -> x @@ -205,15 +199,13 @@ constant_folder = PatternMatcher([ (UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)), # ** two stage add/sub folding ** - (UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.ADD, - [UPat(name="x"), UPat(UOps.CONST, name="c1")]), UPat(UOps.CONST, name="c2")]), + (UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="x"), UPat(UOps.CONST, name="c1")]), UPat(UOps.CONST, name="c2")]), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))), - (UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.SUB, - (UPat(name="x"), UPat(UOps.CONST, name="c1"))), UPat(UOps.CONST, name="c2")]), + (UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.SUB, (UPat(name="x"), UPat(UOps.CONST, name="c1"))), UPat(UOps.CONST, name="c2")]), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c2.arg, c1.arg]))), # TODO: can do the invert of this (flip alt/load) when we fix double ops (UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.ALU, TernaryOps.WHERE, - (UPat(name="gate"), UPat(name="alt"), UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))))), + (UPat(name="gate"), UPat(name="alt"), UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))))), lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))), # store float4/float2 directly (remove CAST/GEP) (UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.CAST, vin= @@ -223,15 +215,12 @@ constant_folder = PatternMatcher([ tuple(UPat(UOps.GEP, i, vin=(UPat(name="val"),)) for i in range(2))))), lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))), # CAST-PHI-GEP -> PHI-CAST - (UPat(UOps.CAST, name="root", vin= - tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))), + (UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))), lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))), - (UPat(UOps.CAST, name="root", vin= - tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))), + (UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))), lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))), # NEG/CMPLT -> CMPLT - (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)), - UPat(UOps.CONST, name="c", dtype=dtypes.int))), + (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)), UPat(UOps.CONST, name="c", dtype=dtypes.int))), lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)), # cast NOOP (NOTE: it's str to deal with PtrDType) (UPat(UOps.CAST, name="root"), lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None), diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 0eb8c3cc4a..95930589a0 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -235,27 +235,25 @@ ptx_matcher = PatternMatcher([ (UPat(UOps.ALU, BinaryOps.DIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]), vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]), lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)), - (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), - lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)), + (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)), (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"), - lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)), + lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)), (UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"), lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)), *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"), lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),))) for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half], - (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, - vin=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))), - lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)), + (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))), + lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)), (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(),UPat())), - lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))), + lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))), (UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())), - lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)), + lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)), (UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))), - lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)), + lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)), (UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))), - lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)), + lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)), # ptr_ar (load/store) (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),