proper wmma (#2245)

* proper wmma

* hip cast

* bugfixes

* bugfix

* that bug is fixed

---------

Co-authored-by: George Hotz <george@tinygrad.org>
This commit is contained in:
George Hotz
2023-11-09 15:15:18 -08:00
committed by GitHub
parent b7a31fb708
commit 80bf0b8586
6 changed files with 65 additions and 32 deletions

View File

@@ -27,13 +27,19 @@ repos:
pass_filenames: false pass_filenames: false
- id: tests - id: tests
name: subset of (CPU) tests name: subset of (CPU) tests
entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py test/external/test_example.py entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
language: system
always_run: true
pass_filenames: false
- id: example
name: multi device tests
entry: python3 test/external/test_example.py
language: system language: system
always_run: true always_run: true
pass_filenames: false pass_filenames: false
- id: pylint - id: pylint
name: pylint name: pylint
entry: python -m pylint tinygrad/ entry: python3 -m pylint tinygrad/
language: system language: system
always_run: true always_run: true
pass_filenames: false pass_filenames: false

View File

@@ -67,13 +67,13 @@ with open("/tmp/cc2.elf", "wb") as f:
f.write(asm) f.write(asm)
print(colored("creating CLProgram", "green")) print(colored("creating CLProgram", "green"))
prg = CLProgram("code", asm, binary=True) prg = CLProgram("code", asm)
print(colored("running program", "green")) print(colored("running program", "green"))
G = 512 G = 512
FLOPS *= 100000*G*G # loop * global_size FLOPS *= 100000*G*G # loop * global_size
for i in range(3): for i in range(3):
tm = prg([G//256, G], [256, 1], buf, wait=True) tm = prg(buf, global_size=[G//256, G, 1], local_size=[256, 1, 1], wait=True)
print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS") print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")
print(colored("transferring buffer", "green")) print(colored("transferring buffer", "green"))

View File

@@ -24,7 +24,7 @@ class TestLinearizerFailures(unittest.TestCase):
lin = Linearizer(ast) lin = Linearizer(ast)
assert fuzz_linearizer(lin) != "PASS" assert fuzz_linearizer(lin) != "PASS"
@unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU", "CLANG"], "fails on these backends") @unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU"], "fails on these backends")
def test_failure_3(self): def test_failure_3(self):
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1)) ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1))
lin = Linearizer(ast) lin = Linearizer(ast)

View File

@@ -209,15 +209,11 @@ class Linearizer(Kernel):
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}
self.loop_uops.update(new_loops) self.loop_uops.update(new_loops)
return tuple(new_loops.values()) return tuple(new_loops.values())
def end_loop(xx:List[Variable]):
for x in xx[::-1]:
if not isinstance(x, NumNode) and x.expr is not None:
loop_uop = self.loop_uops[x.expr]
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
# set global/local size # set global/local size
self.global_size: Optional[List[int]] = None self.global_size: Optional[List[int]] = None
self.local_size: Optional[List[int]] = None self.local_size: Optional[List[int]] = None
global_loop_ctx: Tuple[UOp, ...] = tuple()
if self.dont_use_locals: if self.dont_use_locals:
self.global_size = [x.max+1 for x in loop_global_idxs][::-1] self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
@@ -228,7 +224,7 @@ class Linearizer(Kernel):
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
else: else:
render_loop(loop_global_idxs+loop_local_idxs) global_loop_ctx = render_loop(loop_global_idxs+loop_local_idxs)
# parse AST # parse AST
loaded_buffers = {} loaded_buffers = {}
@@ -296,8 +292,16 @@ class Linearizer(Kernel):
for y in range(by): for y in range(by):
for x in range(bx): for x in range(bx):
for j in range(acc_reds): for j in range(acc_reds):
# TODO: make this a proper op with PHI node op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]]
self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) if self.opts.device != "HIP":
ops = tuple(op1+op2+op3)
else:
ops = (self.uop(UOps.CAST, dtypes._half16, tuple(op1)),
self.uop(UOps.CAST, dtypes._half16, tuple(op2)),
self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
for z in range(cast(DType, ret.dtype).sz):
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + global_loop_ctx + loop_ctx)
i += wmma_sz[2] i += wmma_sz[2]
else: else:
if locals_to_store: if locals_to_store:
@@ -309,10 +313,9 @@ class Linearizer(Kernel):
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs}) loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
# run early AST (with reduce) # run early AST (with reduce)
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=global_loop_ctx + loop_ctx)
# end the reduce loop # end the reduce loop
end_loop(reduce_idxs)
self.load_cache.clear() self.load_cache.clear()
# end the local loop, do the local reduce # end the local loop, do the local reduce
@@ -320,7 +323,6 @@ class Linearizer(Kernel):
fake_global_idxs = [x*0 for x in global_idxs] fake_global_idxs = [x*0 for x in global_idxs]
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
self.uop(UOps.BARRIER, None, (), cachable=False) self.uop(UOps.BARRIER, None, (), cachable=False)
end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
if self.opts.has_local: if self.opts.has_local:
fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape) fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:] fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
@@ -356,24 +358,23 @@ class Linearizer(Kernel):
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
# end the late reduce loop # end the late reduce loop
end_loop(end_local_idxs)
self.load_cache.clear() self.load_cache.clear()
# load latebufs # load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
# run late AST # run late AST
val = self.ast_parse(self.ast, acc, None, loaded_buffers) val = self.ast_parse(self.ast, acc, None, loaded_buffers, loop_ctx=global_loop_ctx)
# store # store
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
# end the global (and maybe local) loop # end the if statement if we used it
if if_gate: self.uop(UOps.END, None, (if_gate,)) if if_gate: self.uop(UOps.END, None, (if_gate,))
end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
# (recursively) remove childless uops # (recursively) remove childless uops
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL} # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
while 1: while 1:
has_child: Set[UOp] = set() has_child: Set[UOp] = set()
for ru in self.uops: for ru in self.uops:
@@ -384,7 +385,28 @@ class Linearizer(Kernel):
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}") if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu self.uops = nu
def get_recursive_deps(x:UOp) -> List[UOp]:
deps = set([x])
ssize = 0
while ssize != len(deps):
ssize = len(deps)
for u in self.uops:
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
deps.add(u)
return sorted(list(deps), key=lambda x: x.num)
# add END of loops after the last thing that (recursively) depends on them
for u in self.uops:
if u.uop == UOps.LOOP:
last_phi = self.uops.index(get_recursive_deps(u)[-1])
at_end = self.uops[last_phi+1:]
self.uops = self.uops[:last_phi+1]
self.uop(UOps.END, None, (u,), cachable=False)
self.uops += at_end
# maybe graph the uops # maybe graph the uops
if DEBUG >= 5:
for u in self.uops: print(u)
if getenv("GRAPHUOPS"): if getenv("GRAPHUOPS"):
from tinygrad.graph import graph_uops from tinygrad.graph import graph_uops
graph_uops(self.uops) graph_uops(self.uops)
@@ -415,7 +437,7 @@ class Linearizer(Kernel):
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0] if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
if cachable and key in self.saved_exprs: return self.saved_exprs[key] if cachable and key in self.saved_exprs: return self.saved_exprs[key]
self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops))) self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
if DEBUG >= 5: print(self.uops[-1]) #if DEBUG >= 5: print(self.uops[-1])
if cachable: self.saved_exprs[key] = self.uops[-1] if cachable: self.saved_exprs[key] = self.uops[-1]
return self.uops[-1] return self.uops[-1]

View File

@@ -130,8 +130,10 @@ class dtypes:
# NOTE: these are internal dtypes, should probably check for that # NOTE: these are internal dtypes, should probably check for that
_int2: Final[DType] = DType(2, 4*2, "int2", None, 2) _int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4) _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
_half16: Final[DType] = DType(0, 2*16, "half16", None, 16)
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2) _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4) _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
_float8: Final[DType] = DType(4, 4*8, "float8", None, 8)
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None) _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
# NOTE: these are image dtypes # NOTE: these are image dtypes

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast
import math import math
from collections import defaultdict from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.codegen.linearizer import UOps, UOp
@@ -46,6 +46,8 @@ class CStyleLanguage(NamedTuple):
if len(x) == 1: return f"({var_dtype.name})({x[0]})" if len(x) == 1: return f"({var_dtype.name})({x[0]})"
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}" assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert self.float4 is not None, "cast is not supported on this platform" assert self.float4 is not None, "cast is not supported on this platform"
if var_dtype == dtypes._half16: return f"{{{','.join(f'(half){x}' for x in x)}}}"
if var_dtype == dtypes._float8: return f"{{{','.join(x)}}}"
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})" if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})" if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})" if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})"
@@ -141,21 +143,19 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
kk("}") kk("}")
elif uop == UOps.WMMA: elif uop == UOps.WMMA:
if args[0] == "METAL": if args[0] == "METAL":
assert dtype == dtypes._float2, "output dtype of METAL TC is _float2"
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2)) # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
output = ssa(u, 'wmma')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};")
kk("{ simdgroup_float8x8 a,b,c;") kk("{ simdgroup_float8x8 a,b,c;")
kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};") kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};") kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};") kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
kk("simdgroup_multiply_accumulate(c, a, b, c);") kk("simdgroup_multiply_accumulate(c, a, b, c);")
kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}") kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
elif args[0] == "HIP": elif args[0] == "HIP":
kk("{") assert dtype == dtypes._float8, "output dtype of HIP TC is _float8"
kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};") kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};")
kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};")
kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];")
kk("}")
else: else:
raise NotImplementedError(f"WMMA not implemented for {args}") raise NotImplementedError(f"WMMA not implemented for {args}")
elif uop == UOps.ALU: elif uop == UOps.ALU:
@@ -205,7 +205,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
bufs.append(args) bufs.append(args)
r[u] = args[0] r[u] = args[0]
elif uop == UOps.GEP: elif uop == UOps.GEP:
r[u] = f"({r[vin[0]]}).{'xyzw'[args]}" if cast(DType, vin[0].dtype).sz > 4:
r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP
else:
r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
else: else:
raise RuntimeError(f"failed to render {uop}") raise RuntimeError(f"failed to render {uop}")