diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index d9602988a9..3b78aebc9f 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -4,7 +4,7 @@ import functools, itertools, heapq, math, operator from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element -from tinygrad.ops import UPat, PatternMatcher, graph_rewrite +from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES if TYPE_CHECKING: from tinygrad.renderer import Renderer @@ -284,6 +284,11 @@ def get_extra_patterns(ops): UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None), (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", src=(UPat.var("div"), UPat.cvar("const"))), lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None)] + if UnaryOps.NEG in ops: + pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))] + if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))] + if TernaryOps.MULACC in ops: + pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))] return PatternMatcher(pat) # ***** threefry ***** @@ -361,9 +366,10 @@ def no_vectorized_wmma(wmma:UOp): # this is symbolic 2.0 constant_folder = PatternMatcher([ - # bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly - (UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)), - (UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)), + # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly + (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y), + (UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y), + (UPat.var('x', dtype=dtypes.bool).max(UPat.var('y')), lambda x,y: x|y), # self ASSIGN is just self (UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x), # ASSIGN to global is just self diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 7e491514a5..bdb0719fbc 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -33,19 +33,15 @@ asm_for_op: Dict[Op, Callable] = { } supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] -shiftable_consts = set([2**i for i in range(64)]) ptx_matcher = constant_folder+PatternMatcher([ - (UPat(UOps.ALU, arg=BinaryOps.CMPNE, src=(UPat(dtype=dtypes.bool),UPat()), name="root"), - lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)), - (UPat(UOps.ALU, arg=BinaryOps.CMPLT, src=(UPat.var("x", dtype=dtypes.bool),UPat.var("y")), name="root"), - lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x, UOp.const(dtypes.bool, True)), BinaryOps.CMPNE), y), BinaryOps.MUL)), - (UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat.var("non_muls"), UPat(UOps.ALU, arg=BinaryOps.MUL, name="muls")], name="root"), - lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)), + # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only) + (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y), + (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y), + # upcast to float32 all the ops that don't support half *[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"), - lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half))) + lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))) for op in asm_for_op.keys() if op not in supports_half], - (UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX), - lambda x: UOp(UOps.ALU, dtypes.uint8, tuple(s.cast(dtypes.uint8) for s in x.src), x.arg).cast(dtypes.bool)), + # fix the gates for load/store (low quality!) (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat.var("x"),UPat.var("y"),UPat.var("z"),UPat.var("k"))), lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)), (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())), @@ -66,8 +62,7 @@ ptx_matcher = constant_folder+PatternMatcher([ lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64), UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])), - (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), - UPat.var("alu"))), # no const here + (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.var("alu"))), lambda root, alu: UOp(root.op, root.dtype, (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64), UOp.const(dtypes.int64, 0))+root.src[2:])), diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 980caaba3f..cf421d19ef 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -112,9 +112,10 @@ class CStyleLanguage(Renderer): code_for_op: Dict = { UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"(1/{x})", + UnaryOps.NEG: lambda x,dtype: f"-{x}", UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})", BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})", - BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", + BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",