From 31c8ba8b84eb3b7d0603f09e8be5e6ae934656b5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com>
Date: Sat, 30 Mar 2024 03:42:39 +0100
Subject: [PATCH] Move transformations to PatternMatcher + clean up existing
 patterns (#3997)

---
 tinygrad/renderer/assembly.py | 41 ++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 22 deletions(-)

diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index 2a224954af..949a0d6537 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -45,25 +45,31 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
 
   def eq_rep(root, x, y):
     root.arg = BinaryOps.XOR
-    new = uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)
-    return new
+    return uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)
 
   def lt_rep(x, y):
-    new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u))
-    u.vin = (new, u.vin[1])
+    u.vin = (uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)), u.vin[1])
     u.arg = BinaryOps.MUL
 
   def ld_rep(root, x, y):
     root.dtype = dtypes.uint8
-    new = uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)
     ptr_ar(root)
-    return new
+    return uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)
+
+  def st_rep(root, z):
+    root.vin = root.vin[:2] + (uops.add(UOps.CAST, dtypes.uint8, (z,), insert_before=uops.uops.index(root)),)
+    ptr_ar(root)
 
   def gate_rep(root, x, y, z, k):
-    new = uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root))
-    root.vin = (x,y,z,new)
+    root.vin = (x,y,z,uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root)))
     return ld_rep(root,x,y)
 
+  def half_alu_rep(root):
+    if root.arg not in lang.supports_half:
+      root.vin = tuple([uops.add(UOps.CAST, dtypes.float32, (vv,), insert_before=(uops.uops.index(root))) for vv in root.vin])
+      root.dtype = dtypes.float32
+      return uops.add(UOps.CAST, dtypes.half, (root,), insert_before=(uops.uops.index(root)+1))
+
   def ptr_ar(root):
     assert root.arg in {'.shared', '.global', None}
     if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
@@ -79,9 +85,12 @@
   matcher = PatternMatcher([
     ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, eq_rep),
     ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, lt_rep),
+    ({"__name__": "root", "uop": UOps.ALU, "dtype": dtypes.half}, half_alu_rep),
     ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool,
       "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})}, gate_rep),
     ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({"__name__": "x"},{"__name__": "y"})}, ld_rep),
+    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})}, st_rep),
+    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})}, st_rep),
     ({"__name__": "root", "uop": UOps.STORE, "vin": {}}, ptr_ar),
     ({"__name__": "root", "uop": UOps.LOAD, "vin": {}}, ptr_ar),
   ])
@@ -155,21 +164,11 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
     if uop is UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
     elif uop is UOps.ALU:
       assert vin[0].dtype is not None
-      operands = [r[x] for x in vin]
-      lab = ssa(u, "alu")
-      if needs_upcast := dtype == dtypes.half and args not in lang.supports_half:
-        dtype = dtypes.float32
-        out_lab, lab = lab, ssa(None, "alu_cast", lang.types[dtype])
-        for i, op in enumerate(operands):
-          operands[i] = ssa(None, "alu_cast", lang.types[dtype])
-          kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) # type: ignore
       if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
         # pass in the other dtype here
-        kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype]))
+        kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
       else:
-        kk(lang.asm_for_op[args](lab, *operands, dtype, lang.types[dtype]))
-        if needs_upcast:
-          kk(*lang.render_cast(out_lab, lab, dtypes.half, dtype))
+        kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
     elif uop is UOps.DEFINE_ACC:
       if dtype.count > 1:
         r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
@@ -275,8 +274,6 @@ class PTXLanguage(AssemblyLanguage):
     else: return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
 
   def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
-    if dtype == dtypes.bool: return [f".reg .s16 {val}_cast;", *self.render_cast(f"{val}_cast", val, dtypes.int16, dtype),
-                                     (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val}_cast;"]
     return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"]
 
   def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
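
Illustrative sketch (not part of the patch): the rewrite functions above are registered in a PatternMatcher table that pairs a dict describing a UOp's shape (uop kind, dtype, inputs) with a function that mutates or replaces the matched node. The minimal, self-contained model below shows only that idea; Node, match, rewrite, the string-valued uop/dtype fields, and the SUPPORTS_HALF set are all invented for illustration and are not tinygrad's actual UOp/PatternMatcher API. The example rewrite mirrors half_alu_rep: unsupported half ALUs get their inputs upcast to float32 and the result cast back to half.

# Minimal dict-pattern rewriter, loosely modeled on the patterns in this patch.
# Node, match, rewrite, and SUPPORTS_HALF are assumptions made for this sketch.
from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class Node:
    uop: str            # e.g. "ALU", "LOAD", "STORE", "CAST"
    dtype: str          # e.g. "half", "float32", "bool"
    vin: tuple = ()     # input nodes
    arg: Any = None     # e.g. the ALU op name

def match(pattern: dict, node: Node) -> bool:
    # A node matches when every attribute named in the pattern equals the node's value.
    return all(getattr(node, k) == v for k, v in pattern.items() if k != "__name__")

SUPPORTS_HALF = {"ADD", "MUL"}   # assumption: ops the target handles natively in half

def half_alu_rep(node: Node) -> Optional[Node]:
    # Upcast an unsupported half ALU to float32 and cast the result back to half.
    # Returning None means "no rewrite needed", as in the patch's half_alu_rep.
    if node.arg in SUPPORTS_HALF: return None
    node.vin = tuple(Node("CAST", "float32", (v,)) for v in node.vin)
    node.dtype = "float32"
    return Node("CAST", "half", (node,))

PATTERNS: list[tuple[dict, Callable[[Node], Optional[Node]]]] = [
    ({"uop": "ALU", "dtype": "half"}, half_alu_rep),
]

def rewrite(node: Node) -> Node:
    # Apply the first matching rewrite; keep the original node if none fires.
    for pat, fn in PATTERNS:
        if match(pat, node):
            out = fn(node)
            if out is not None: return out
    return node

if __name__ == "__main__":
    x, y = Node("CONST", "half"), Node("CONST", "half")
    print(rewrite(Node("ALU", "half", (x, y), arg="MAX")))  # wrapped in a CAST back to half
    print(rewrite(Node("ALU", "half", (x, y), arg="ADD")))  # unchanged: ADD supports half

The per-node rewrite is the same shape as what the patch moves into PatternMatcher: the render loop in uops_to_asm no longer special-cases half upcasting or bool stores, because those cases have already been rewritten into plain CAST/STORE nodes before rendering.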