Patch: tinygrad/renderer/assembly.py -- PTX pre-transform rules build replacement UOps instead of mutating the graph

What changes inside uops_to_asm:
  * The in-place callbacks eq_rep, lt_rep, ld_rep, st_rep, gate_rep and half_alu_rep are removed. Every
    PatternMatcher rule now carries a lambda that returns freshly constructed UOp nodes.
  * The half-precision rule is expanded into one rule per op in lang.asm_for_op that is not in
    lang.supports_half, replacing the runtime check that half_alu_rep performed.
  * ptr_ar receives the graph explicitly (ptr_ar(root, uops)) and is run over every LOAD/STORE in a
    separate pass instead of being reached through matcher rules.
  * A new loop inserts the replacement nodes into uops.uops and removes the uop each one replaces.

Review notes:
  * The hunks were restored to one diff line per source line. Indentation of context and removed lines is
    reconstructed from the file's 2-space style; verify against the target file, or apply with
    `git apply --ignore-whitespace`.
  * Gated bool LOAD rule: the inner load is now dtypes.uint8 and keeps root.arg, and the outer CAST has no
    arg. This matches the removed gate_rep/ld_rep and the ungated bool LOAD rule. The post-image blob id on
    the index line predates this adjustment.
  * Hunk 3 carries four explanatory comment lines, so its new-side length is 22.
  * NOTE(review): both bool STORE rules rebuild vin as root.vin[:2] + (cast,), so a fourth operand is not
    carried over. The removed st_rep did the same; confirm that this is intended.

diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index 949a0d6537..f7915c3067 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -43,34 +43,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
   kernel:List[str] = []
   bufs = []
 
-  def eq_rep(root, x, y):
-    root.arg = BinaryOps.XOR
-    return uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)
-
-  def lt_rep(x, y):
-    u.vin = (uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)), u.vin[1])
-    u.arg = BinaryOps.MUL
-
-  def ld_rep(root, x, y):
-    root.dtype = dtypes.uint8
-    ptr_ar(root)
-    return uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)
-
-  def st_rep(root, z):
-    root.vin = root.vin[:2] + (uops.add(UOps.CAST, dtypes.uint8, (z,), insert_before=uops.uops.index(root)),)
-    ptr_ar(root)
-
-  def gate_rep(root, x, y, z, k):
-    root.vin = (x,y,z,uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root)))
-    return ld_rep(root,x,y)
-
-  def half_alu_rep(root):
-    if root.arg not in lang.supports_half:
-      root.vin = tuple([uops.add(UOps.CAST, dtypes.float32, (vv,), insert_before=(uops.uops.index(root))) for vv in root.vin])
-      root.dtype = dtypes.float32
-      return uops.add(UOps.CAST, dtypes.half, (root,), insert_before=(uops.uops.index(root)+1))
-
-  def ptr_ar(root):
+  def ptr_ar(root, uops):
     assert root.arg in {'.shared', '.global', None}
     if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global'  # move this to the argL
     val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
@@ -83,16 +56,22 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
     root.vin = (fptr, zero) + root.vin[2:]
 
   matcher = PatternMatcher([
-    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, eq_rep),
-    ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, lt_rep),
-    ({"__name__": "root", "uop": UOps.ALU, "dtype": dtypes.half}, half_alu_rep),
-    ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool,
-      "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})}, gate_rep),
-    ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({"__name__": "x"},{"__name__": "y"})}, ld_rep),
-    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})}, st_rep),
-    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})}, st_rep),
-    ({"__name__": "root", "uop": UOps.STORE, "vin": {}}, ptr_ar),
-    ({"__name__": "root", "uop": UOps.LOAD, "vin": {}}, ptr_ar),
+    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
+     lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
+    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
+     lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
+    *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
+       lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
+      for op in lang.asm_for_op.keys() if op not in lang.supports_half],
+    ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
+      "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
+     lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,))), root.arg),))),
+    ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
+     lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
+    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
+     lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
+    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
+     lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
   ])
 
   # here we do a pretransform on UOps to fix some shortcomings of PTX
@@ -106,6 +85,22 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
       if o in u.vin and u is not n:
         u.vin = tuple(n if x == o else x for x in u.vin)
     if rew := matcher.rewrite(u): replace[u] = rew
+
+  # the rules build UOp nodes that are not in the graph yet: add them operands-first, then drop the uop they replace
+  for o,n in replace.items():
+    queue = [n]
+    while queue:
+      if all([qq in uops.uops for qq in queue[-1].vin]):
+        q = queue.pop()
+        # insert directly after the last operand, all of which are already in the graph at this point
+        new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([uops.uops.index(vv) for vv in q.vin])+1)
+        # users (in the graph or still queued) that referenced the detached node now reference the inserted one
+        for vv in uops.uops + queue: vv.vin = tuple(new if x is q else x for x in vv.vin)
+      else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops])
+    uops.uops.remove(o)
+
+  # address fixups run over a snapshot of the LOAD/STORE uops because ptr_ar inserts into uops.uops
+  for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
   uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
 
   def kk(*s: str): kernel.append("\n".join(s))