From 624bc899102a58df15557659a45f0cab5f745fee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com>
Date: Fri, 22 Mar 2024 16:54:02 +0100
Subject: [PATCH] PTX - implement float 4, ptr arithmetics and other speed
 improvements (#3775)

* ptx float4 implementation
* remove from cache when trimming uops
* Gate for float4
* Linting fix
* disable test reasonable time for ptx
* import getenv
* Update uops.py
* linter
* Add div test for half
* upcast if op does not support operation
* fix offset
* Run only if dtype supported
* zero out registers when accessing by pred + cleanup
* Remove trailing whitespace
* revert
* spacing fix
* move cache clearing outside loop
* did this suddenly start working?
* unused import removed
* Remove cast
* Use pattern matching
* linting

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
---
 tinygrad/codegen/uops.py      |   1 +
 tinygrad/renderer/assembly.py | 181 ++++++++++++++++++++++------------
 tinygrad/runtime/ops_cuda.py  |   3 +-
 3 files changed, 118 insertions(+), 67 deletions(-)

diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 13d356bf11..100446d1db 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -170,6 +170,7 @@ class UOpGraph:
       if len(nu) == len(self.uops): break
       if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
       self.uops = nu
+      self.saved_exprs = {k:v for k,v in self.saved_exprs.items() if v in nu}
 
   # optional
  def type_verify(self):
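Context for the two changes: the uops.py hunk above keeps the common-subexpression cache saved_exprs in sync when dead uops are trimmed, so the dedup path can never hand back a uop that no longer exists in the graph. The assembly.py diff below then replaces the hand-rolled bool/pointer rewrites with PatternMatcher from tinygrad.codegen.uops. Roughly: a pattern is a dict of UOp attributes to match, "__name__" captures the matched node under that name, "vin" matches the inputs positionally (an empty dict matches any inputs), and named captures are passed to the callback as keyword arguments. A simplified sketch of those semantics (approximate; the real implementation lives in tinygrad/codegen/uops.py):

  def rewrite(uop, patterns):
    for pattern, fxn in patterns:
      store = {}
      if _match(uop, pattern, store): return fxn(**store)  # a None result means "rewritten in place"

  def _match(uop, pattern, store):
    for k, v in pattern.items():
      if k == "__name__": store[v] = uop         # capture this node under the given name
      elif k == "vin" and isinstance(v, tuple):  # a plain {} for "vin" matches any inputs
        if len(uop.vin) != len(v) or not all(_match(x, p, store) for x, p in zip(uop.vin, v)): return False
      elif k != "vin" and getattr(uop, k, None) != v: return False
    return True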
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index 36e259e35f..9b27f344be 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -1,14 +1,17 @@
-from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
+from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Set
 import functools, struct
 from collections import defaultdict
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
 from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
-from tinygrad.codegen.uops import UOpGraph
+from tinygrad.codegen.uops import UOpGraph, PatternMatcher
 
-def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
-def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
-def trunc_float(x, fmt): return struct.unpack(fmt, struct.pack(fmt, x))[0]
+def render_val(x, dtype):
+  if dtypes.is_float(dtype):
+    if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+    elif dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
+    return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
+  return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
 
 def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
 
@@ -23,6 +26,7 @@ class AssemblyLanguage(NamedTuple):
   const_requires_mov: List[DType] = [] # list of dtypes for which creating a const requires a move
   asm_for_op: Dict[Op, Callable[...,str]] = {}
   types: Dict[DType, str] = INVERSE_DTYPES_DICT
+  supports_half: List[Op] = []
 
   def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
   def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
@@ -30,47 +34,77 @@ class AssemblyLanguage(NamedTuple):
   def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
   def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
   def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
-  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: raise NotImplementedError()
-  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: raise NotImplementedError()
+  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
+  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
   def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()
   def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
+  def mem_type(self, dtype) -> str: raise NotImplementedError()
 
 def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
   kernel:List[str] = []
   bufs = []
 
+  def eq_rep(root, x, y):
+    root.arg = BinaryOps.XOR
+    new = uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)
+    return new
+
+  # use the named "root" capture rather than closing over the loop variable below
+  def lt_rep(root, x, y):
+    new = uops.add(UOps.ALU, dtypes.bool, (root.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(root))
+    root.vin = (new, root.vin[1])
+    root.arg = BinaryOps.MUL
+
+  def ld_rep(root, x, y):
+    root.dtype = dtypes.uint8
+    new = uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)
+    ptr_ar(root)
+    return new
+
+  def gate_rep(root, x, y, z, k):
+    new = uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root))
+    root.vin = (x,y,z,new)
+    return ld_rep(root,x,y)
+
+  def ptr_ar(root):
+    root.arg = '.shared' if root.vin[0].uop == UOps.DEFINE_LOCAL else '.global'  # move this to the arg
+    if root.vin[0].dtype.itemsize > 1:
+      val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
+      ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
+    else: ptr = root.vin[1]
+    if ptr.uop == UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
+    else:
+      zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
+      bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
+      fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
+      root.vin = (fptr, zero) + root.vin[2:]
+
+  matcher = PatternMatcher([
+    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, eq_rep),
+    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, lt_rep),
+    ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
+      "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})}, gate_rep),
+    ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool, "vin": ({"__name__": "x"},{"__name__": "y"})}, ld_rep),
+    ({"__name__": "root", "uop": UOps.STORE, "vin": {}}, ptr_ar),
+    ({"__name__": "root", "uop": UOps.LOAD, "vin": {}}, ptr_ar),
+  ])
+
   # here we do a pretransform on UOps to fix some shortcomings of PTX
   # all uops must be a register
-  # TODO: uops class should make these rewrites easier
   replace: Dict[UOp, UOp] = {}
+  seen: Set[UOp] = set()
   for u in uops:
+    if u in seen: continue
+    seen.add(u)
     for o,n in replace.items():
       if o in u.vin and u is not n: u.vin = tuple(n if x == o else x for x in u.vin)
-    if u.uop is UOps.LOAD and u.dtype is dtypes.bool:
-      # rewrite load bool
-      if len(u.vin) == 4:
-        new = uops.add(UOps.CAST, dtypes.uint8, (u.vin[3],), insert_before=uops.uops.index(u))
-        u.vin = u.vin[0:3] + (new,)
-      u.dtype = dtypes.uint8
-      new = uops.add(UOps.CAST, dtypes.bool, (u,), insert_before=uops.uops.index(u)+1)
-      replace[u] = new
-    if u.uop is UOps.ALU and u.arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT} and u.vin[0].dtype is dtypes.bool:
-      if u.arg == BinaryOps.CMPEQ:
-        u.arg = BinaryOps.XOR
-        new = uops.add(UOps.ALU, dtypes.bool, (u,), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)+1)
-        replace[u] = new
-      if u.arg == BinaryOps.CMPLT:
-        new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u))
-        u.vin = (new, u.vin[1])
-        u.arg = BinaryOps.MUL
-    #uops.print()
+    if rew := matcher.rewrite(u): replace[u] = rew
 
   def kk(*s: str): kernel.append("\n".join(s))
 
   c: DefaultDict[str, int] = defaultdict(int)
-  r: Dict[UOp, str] = {}
+  r: Dict[UOp, Union[List[str], str]] = {}
   def ssa(u, prefix="t", dtype=None) -> str:
     nonlocal c, r
     prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
@@ -92,7 +126,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
       return out
     return lang.render_const(x, dtype)
 
-  def cast(a:str, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
+  def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
     if atype == dtype:
       if u: r[u] = a
       return a
@@ -113,45 +147,70 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
       kk(f"{r_label[vin[0]]}:")
     elif uop == UOps.STORE:
       assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
-      kk(*lang.render_gep(loc:=ssa(None,'loc','u64'), r[vin[0]], r[vin[1]], vin[0].dtype))
-      if len(vin) > 3:
-        assert vin[3].dtype is not None
-        pred = cast(r[vin[3]], dtypes.bool, vin[3].dtype, pred=True)
-      kk(*lang.render_store(loc, r[vin[2]], vin[0].dtype, gate=pred if len(vin)>3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
+      if vin[2].dtype.count > 1:
+        kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
+           f"st{u.arg}.v{vin[2].dtype.count}.{lang.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
+      else:
+        kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
     else:
       assert dtype is not None, f"None dtype for uop {uop}"
      if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
      elif uop == UOps.ALU:
        assert vin[0].dtype is not None
+        operands = [r[x] for x in vin]
+        lab = ssa(u, "alu")
+        # ops PTX can't do in half get upcast to float32, then the result is cast back down
+        if needs_upcast := dtype == dtypes.half and args not in lang.supports_half:
+          dtype = dtypes.float32
+          out_lab, lab = lab, ssa(None, "alu_cast", lang.types[dtype])
+          for i, op in enumerate(operands):
+            operands[i] = ssa(None, "alu_cast", lang.types[dtype])
+            kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) # type: ignore
        if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
          # pass in the other dtype here
-          kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
+          kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype]))
        else:
-          kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
-      elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
+          kk(lang.asm_for_op[args](lab, *operands, dtype, lang.types[dtype]))
+        if needs_upcast:
+          kk(*lang.render_cast(out_lab, lab, dtypes.half, dtype))
+      elif uop == UOps.DEFINE_ACC:
+        if dtype.count > 1:
+          r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
+          for uu in r[u]:
+            kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
+        else:
+          kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
      elif uop == UOps.SPECIAL:
        assert args[1][0] != "i", "idx not supported"
        kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
        r[u] = "%" + args[1]
        kernel = [f".reg .u32 %{args[1]};"] + kernel
-      elif uop == UOps.CONST: r[u] = const(args, dtype, mov=True)
+      elif uop == UOps.CONST:
+        if dtype.count > 1:
+          r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
+        else:
+          r[u] = const(args, dtype, mov=True)
+      elif uop == UOps.GEP:
+        r[u] = r[vin[0]][u.arg]
      elif uop == UOps.LOAD:
        assert vin[1].dtype is not None
-        val = ssa(u, 'val')
-        if len(vin) > 3:
-          assert vin[2].dtype is not None
-          pred = cast(r[vin[2]], dtypes.bool, vin[2].dtype, pred=True)
-          off = cast(r[vin[1]], dtypes.uint, vin[1].dtype)
-        kk(*lang.render_gep(loc:=ssa(None,'loc',lang.types[dtypes.ulong]), r[vin[0]], off if len(vin)>3 else cast(r[vin[1]],
-          dtypes.uint, vin[1].dtype), dtype),
-          *lang.render_load(loc, val, dtype, gate=pred if len(vin) > 3 else None,
-          alt=r[vin[3]] if len(vin) > 3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
+        if dtype.count > 1:
+          r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
+          if len(vin) > 3:
+            # zero out the destination registers first, so gated-off lanes read defined values
+            for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
+          kk((f"@{r[vin[2]]}" if len(vin) > 3 else "")
+             + f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
+        else:
+          kk(*lang.render_load(r[vin[0]], ssa(u, 'val'), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
+                               alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
      elif uop == UOps.PHI:
        kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
        r[u] = r[vin[0]]
      elif uop in {UOps.CAST, UOps.BITCAST}:
        assert vin[0].dtype is not None
-        cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+        if dtype.count>1:
+          r[u] = [r[x] for x in vin] # type: ignore
+        else:
+          cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
      elif uop == UOps.DEFINE_LOCAL:
        # TODO: we should sum these, and fetch 0xC000 from somewhere
        assert args[1]*dtype.itemsize <= 0xC000, "too large local"
@@ -172,7 +231,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
   return lang.render_kernel(kernel, function_name, bufs, c.items())
 
 class PTXLanguage(AssemblyLanguage):
-  kernel_prefix = """.version 7.8
+  kernel_prefix = """.version 7.5
 .target TARGET
 .address_size 64
 .visible .entry"""
@@ -210,10 +269,8 @@ class PTXLanguage(AssemblyLanguage):
   const_requires_mov = [dtypes.half, dtypes.bool]
 
   def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
-    if dtypes.is_float(dtype): val = f"0f{float_to_hex(x)}" if dtype != dtypes.float64 else f"0d{double_to_hex(x)}"
-    else: val = str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
+    val = render_val(x, dtype)
     if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
-    if dtype == dtypes.half: return [f".reg .f32 {mov}_tmp;", f"mov.f32 {mov}_tmp, {val};", f"cvt.rn.f16.f32 {mov}, {mov}_tmp;"]
     return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
 
   def render_local(self, dest, name, size, dtype) -> List[str]:
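For reference, the new render_val (first assembly.py hunk above) renders constants as PTX literals: floats become raw IEEE-754 bit patterns (0f/0d prefixes for float/double, a plain hex pattern for half), and integers get a U suffix when unsigned. A quick sanity check, runnable once this patch is applied (bit patterns verified with struct; "e" is the binary16 format code):

  import struct
  from tinygrad.dtype import dtypes
  from tinygrad.renderer.assembly import render_val
  assert struct.pack("f", 1.0) == b"\x00\x00\x80\x3f"        # little-endian bytes...
  assert render_val(1.0, dtypes.float32) == "0f3F800000"     # ...reversed into hex digits
  assert render_val(1.0, dtypes.double) == "0d3FF0000000000000"
  assert render_val(1.0, dtypes.half) == "0x3C00"
  assert render_val(7, dtypes.uint32) == "7U"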
[f"setp.ne.s16 {mov}, {val}, 0;"] - if dtype == dtypes.half: return [f".reg .f32 {mov}_tmp;", f"mov.f32 {mov}_tmp, {val};", f"cvt.rn.f16.f32 {mov}, {mov}_tmp;"] return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val def render_local(self, dest, name, size, dtype) -> List[str]: @@ -223,31 +280,25 @@ class PTXLanguage(AssemblyLanguage): def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"] - def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: - # this cast is only required because of ocelot - if "s32" in offset: - return [f".reg .u32 {offset}_cast;", f"cvt.u32.s32 {offset}_cast, {offset};", f"mad.wide.u32 {loc}, {offset}_cast, {dtype.itemsize}, {base};"] - else: return [f"mad.wide.u32 {loc}, {offset}, {dtype.itemsize}, {base};"] - def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype] - def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: + def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: assert dtype is not dtypes.bool ret = [] - if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];", + if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]) - else: ret.append(f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];") + else: ret.append(f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];") return ret - def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: + def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: if dtype == dtypes.bool: return [f".reg .s16 {val}_cast;", *self.render_cast(f"{val}_cast", val, dtypes.int16, dtype), - (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val}_cast;"] - return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val};"] + (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val}_cast;"] + return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"] def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"] - if atype == dtypes.bool: return [f"selp.{self.types[dtype]} {d}, {self.render_const(1, dtype)}, {self.render_const(0, dtype)}, {a};"] - if dtype == dtypes.bool: return [f"setp.ne.{self.types[atype]} {d}, {a}, {self.render_const(0, atype)};"] + if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"] + if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"] rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '') return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"] diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 2405145f84..914266f918 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -52,8 +52,7 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes: return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value) class 
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 2405145f84..914266f918 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -52,8 +52,7 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
   return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
 
 class PTXCompiler(Compiler):
-  linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
-                                      supports_float4=False, shared_max=49152)
+  linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152)
   def __init__(self, arch:str):
     self.arch = arch
     PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
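With supports_float4 no longer forced off (the LinearizerOptions change above; it defaults to True), vectorized buffers now reach the PTX renderer and come out as single v4 memory ops. A hypothetical end-to-end check, assuming the usual tinygrad toggles of this era (CUDA=1 selects the backend, PTX=1 this renderer, DEBUG>=4 dumps generated code): the dump should show ld.global.v4.f32 / st.global.v4.f32 pairs instead of four scalar accesses.

  CUDA=1 PTX=1 DEBUG=4 python3 -c "from tinygrad.tensor import Tensor; print((Tensor.rand(256)+Tensor.rand(256)).realize())"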