From 80bf0b85861ade92c637de4ce948c25f2a8d02a9 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 9 Nov 2023 15:15:18 -0800 Subject: [PATCH] proper wmma (#2245) * proper wmma * hip cast * bugfixes * bugfix * that bug is fixed --------- Co-authored-by: George Hotz --- .pre-commit-config.yaml | 10 ++++-- extra/rocm/rdna3/asm.py | 4 +-- test/test_linearizer_failures.py | 2 +- tinygrad/codegen/linearizer.py | 56 ++++++++++++++++++++++---------- tinygrad/helpers.py | 2 ++ tinygrad/renderer/cstyle.py | 23 +++++++------ 6 files changed, 65 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1b7047675e..77d40565a7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,13 +27,19 @@ repos: pass_filenames: false - id: tests name: subset of (CPU) tests - entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py test/external/test_example.py + entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py + language: system + always_run: true + pass_filenames: false + - id: example + name: multi device tests + entry: python3 test/external/test_example.py language: system always_run: true pass_filenames: false - id: pylint name: pylint - entry: python -m pylint tinygrad/ + entry: python3 -m pylint tinygrad/ language: system always_run: true pass_filenames: false diff --git a/extra/rocm/rdna3/asm.py b/extra/rocm/rdna3/asm.py index f0e9f15c1c..2f6ad13264 100644 --- a/extra/rocm/rdna3/asm.py +++ b/extra/rocm/rdna3/asm.py @@ -67,13 +67,13 @@ with open("/tmp/cc2.elf", "wb") as f: f.write(asm) print(colored("creating CLProgram", "green")) -prg = CLProgram("code", asm, binary=True) +prg = CLProgram("code", asm) print(colored("running program", "green")) G = 512 FLOPS *= 100000*G*G # loop * global_size for i in range(3): - tm = prg([G//256, G], [256, 1], buf, wait=True) + tm = prg(buf, global_size=[G//256, G, 1], local_size=[256, 1, 1], wait=True) print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS") print(colored("transferring buffer", "green")) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 82e310c00f..4010282d21 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -24,7 +24,7 @@ class TestLinearizerFailures(unittest.TestCase): lin = Linearizer(ast) assert fuzz_linearizer(lin) != "PASS" - @unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU", "CLANG"], "fails on these backends") + @unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU"], "fails on these backends") def test_failure_3(self): ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1)) lin = Linearizer(ast) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 7d59f85976..f25058bf2c 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -209,15 +209,11 @@ class Linearizer(Kernel): self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} self.loop_uops.update(new_loops) return tuple(new_loops.values()) - def end_loop(xx:List[Variable]): - for x in xx[::-1]: - if not isinstance(x, NumNode) and x.expr is not None: - loop_uop = self.loop_uops[x.expr] - if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,)) # set global/local size self.global_size: Optional[List[int]] = None self.local_size: Optional[List[int]] = None + global_loop_ctx: Tuple[UOp, ...] = tuple() if self.dont_use_locals: self.global_size = [x.max+1 for x in loop_global_idxs][::-1] self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) @@ -228,7 +224,7 @@ class Linearizer(Kernel): self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) else: - render_loop(loop_global_idxs+loop_local_idxs) + global_loop_ctx = render_loop(loop_global_idxs+loop_local_idxs) # parse AST loaded_buffers = {} @@ -296,8 +292,16 @@ class Linearizer(Kernel): for y in range(by): for x in range(bx): for j in range(acc_reds): - # TODO: make this a proper op with PHI node - self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) + op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]] + if self.opts.device != "HIP": + ops = tuple(op1+op2+op3) + else: + ops = (self.uop(UOps.CAST, dtypes._half16, tuple(op1)), + self.uop(UOps.CAST, dtypes._half16, tuple(op2)), + self.uop(UOps.CAST, dtypes._float8, tuple(op3))) + ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) + for z in range(cast(DType, ret.dtype).sz): + acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + global_loop_ctx + loop_ctx) i += wmma_sz[2] else: if locals_to_store: @@ -309,10 +313,9 @@ class Linearizer(Kernel): loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs}) # run early AST (with reduce) - self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) + self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=global_loop_ctx + loop_ctx) # end the reduce loop - end_loop(reduce_idxs) self.load_cache.clear() # end the local loop, do the local reduce @@ -320,7 +323,6 @@ class Linearizer(Kernel): fake_global_idxs = [x*0 for x in global_idxs] self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators self.uop(UOps.BARRIER, None, (), cachable=False) - end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if? if self.opts.has_local: fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape) fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:] @@ -356,24 +358,23 @@ class Linearizer(Kernel): self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore # end the late reduce loop - end_loop(end_local_idxs) self.load_cache.clear() # load latebufs loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) # run late AST - val = self.ast_parse(self.ast, acc, None, loaded_buffers) + val = self.ast_parse(self.ast, acc, None, loaded_buffers, loop_ctx=global_loop_ctx) # store self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) - # end the global (and maybe local) loop + # end the if statement if we used it if if_gate: self.uop(UOps.END, None, (if_gate,)) - end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs) # (recursively) remove childless uops - UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL} + # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that + UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL} while 1: has_child: Set[UOp] = set() for ru in self.uops: @@ -384,7 +385,28 @@ class Linearizer(Kernel): if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}") self.uops = nu + def get_recursive_deps(x:UOp) -> List[UOp]: + deps = set([x]) + ssize = 0 + while ssize != len(deps): + ssize = len(deps) + for u in self.uops: + if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])): + deps.add(u) + return sorted(list(deps), key=lambda x: x.num) + + # add END of loops after the last thing that (recursively) depends on them + for u in self.uops: + if u.uop == UOps.LOOP: + last_phi = self.uops.index(get_recursive_deps(u)[-1]) + at_end = self.uops[last_phi+1:] + self.uops = self.uops[:last_phi+1] + self.uop(UOps.END, None, (u,), cachable=False) + self.uops += at_end + # maybe graph the uops + if DEBUG >= 5: + for u in self.uops: print(u) if getenv("GRAPHUOPS"): from tinygrad.graph import graph_uops graph_uops(self.uops) @@ -415,7 +437,7 @@ class Linearizer(Kernel): if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0] if cachable and key in self.saved_exprs: return self.saved_exprs[key] self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops))) - if DEBUG >= 5: print(self.uops[-1]) + #if DEBUG >= 5: print(self.uops[-1]) if cachable: self.saved_exprs[key] = self.uops[-1] return self.uops[-1] diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index efad1fcee3..2db293b6f7 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -130,8 +130,10 @@ class dtypes: # NOTE: these are internal dtypes, should probably check for that _int2: Final[DType] = DType(2, 4*2, "int2", None, 2) _half4: Final[DType] = DType(0, 2*4, "half4", None, 4) + _half16: Final[DType] = DType(0, 2*16, "half16", None, 16) _float2: Final[DType] = DType(4, 4*2, "float2", None, 2) _float4: Final[DType] = DType(4, 4*4, "float4", None, 4) + _float8: Final[DType] = DType(4, 4*8, "float8", None, 8) _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None) # NOTE: these are image dtypes diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index fa41755ce6..727e005449 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict +from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast import math from collections import defaultdict from tinygrad.codegen.linearizer import UOps, UOp @@ -46,6 +46,8 @@ class CStyleLanguage(NamedTuple): if len(x) == 1: return f"({var_dtype.name})({x[0]})" assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}" assert self.float4 is not None, "cast is not supported on this platform" + if var_dtype == dtypes._half16: return f"{{{','.join(f'(half){x}' for x in x)}}}" + if var_dtype == dtypes._float8: return f"{{{','.join(x)}}}" if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})" if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})" if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})" @@ -141,21 +143,19 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu kk("}") elif uop == UOps.WMMA: if args[0] == "METAL": + assert dtype == dtypes._float2, "output dtype of METAL TC is _float2" # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2)) + output = ssa(u, 'wmma') + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};") kk("{ simdgroup_float8x8 a,b,c;") kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};") kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};") kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};") kk("simdgroup_multiply_accumulate(c, a, b, c);") - kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}") + kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}") elif args[0] == "HIP": - kk("{") - kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};") - kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};") - kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};") - kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);") - for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];") - kk("}") + assert dtype == dtypes._float8, "output dtype of HIP TC is _float8" + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") else: raise NotImplementedError(f"WMMA not implemented for {args}") elif uop == UOps.ALU: @@ -205,7 +205,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu bufs.append(args) r[u] = args[0] elif uop == UOps.GEP: - r[u] = f"({r[vin[0]]}).{'xyzw'[args]}" + if cast(DType, vin[0].dtype).sz > 4: + r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP + else: + r[u] = f"({r[vin[0]]}).{'xyzw'[args]}" else: raise RuntimeError(f"failed to render {uop}")