proper wmma (#2245)

* proper wmma

* hip cast

* bugfixes

* bugfix

* that bug is fixed

---------

Co-authored-by: George Hotz <george@tinygrad.org>
This commit is contained in:
George Hotz
2023-11-09 15:15:18 -08:00
committed by GitHub
parent b7a31fb708
commit 80bf0b8586
6 changed files with 65 additions and 32 deletions

View File

@@ -27,13 +27,19 @@ repos:
pass_filenames: false
- id: tests
name: subset of (CPU) tests
entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py test/external/test_example.py
entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
language: system
always_run: true
pass_filenames: false
- id: example
name: multi device tests
entry: python3 test/external/test_example.py
language: system
always_run: true
pass_filenames: false
- id: pylint
name: pylint
entry: python -m pylint tinygrad/
entry: python3 -m pylint tinygrad/
language: system
always_run: true
pass_filenames: false

View File

@@ -67,13 +67,13 @@ with open("/tmp/cc2.elf", "wb") as f:
f.write(asm)
print(colored("creating CLProgram", "green"))
prg = CLProgram("code", asm, binary=True)
prg = CLProgram("code", asm)
print(colored("running program", "green"))
G = 512
FLOPS *= 100000*G*G # loop * global_size
for i in range(3):
tm = prg([G//256, G], [256, 1], buf, wait=True)
tm = prg(buf, global_size=[G//256, G, 1], local_size=[256, 1, 1], wait=True)
print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")
print(colored("transferring buffer", "green"))

View File

@@ -24,7 +24,7 @@ class TestLinearizerFailures(unittest.TestCase):
lin = Linearizer(ast)
assert fuzz_linearizer(lin) != "PASS"
@unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU", "CLANG"], "fails on these backends")
@unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU"], "fails on these backends")
def test_failure_3(self):
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1))
lin = Linearizer(ast)

View File

@@ -209,15 +209,11 @@ class Linearizer(Kernel):
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}
self.loop_uops.update(new_loops)
return tuple(new_loops.values())
def end_loop(xx:List[Variable]):
# Emit a UOps.END for every rendered loop variable in xx, walking the list back to front.
for x in xx[::-1]:
# NumNode entries and variables without an expr are skipped
if not isinstance(x, NumNode) and x.expr is not None:
loop_uop = self.loop_uops[x.expr]
# loop_uops also holds non-LOOP uops (e.g. UOps.SPECIAL); only real LOOPs get an END
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
# set global/local size
self.global_size: Optional[List[int]] = None
self.local_size: Optional[List[int]] = None
global_loop_ctx: Tuple[UOp, ...] = tuple()
if self.dont_use_locals:
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
@@ -228,7 +224,7 @@ class Linearizer(Kernel):
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
else:
render_loop(loop_global_idxs+loop_local_idxs)
global_loop_ctx = render_loop(loop_global_idxs+loop_local_idxs)
# parse AST
loaded_buffers = {}
@@ -296,8 +292,16 @@ class Linearizer(Kernel):
for y in range(by):
for x in range(bx):
for j in range(acc_reds):
# TODO: make this a proper op with PHI node
self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]]
if self.opts.device != "HIP":
ops = tuple(op1+op2+op3)
else:
ops = (self.uop(UOps.CAST, dtypes._half16, tuple(op1)),
self.uop(UOps.CAST, dtypes._half16, tuple(op2)),
self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
for z in range(cast(DType, ret.dtype).sz):
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + global_loop_ctx + loop_ctx)
i += wmma_sz[2]
else:
if locals_to_store:
@@ -309,10 +313,9 @@ class Linearizer(Kernel):
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
# run early AST (with reduce)
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=global_loop_ctx + loop_ctx)
# end the reduce loop
end_loop(reduce_idxs)
self.load_cache.clear()
# end the local loop, do the local reduce
@@ -320,7 +323,6 @@ class Linearizer(Kernel):
fake_global_idxs = [x*0 for x in global_idxs]
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
self.uop(UOps.BARRIER, None, (), cachable=False)
end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
if self.opts.has_local:
fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
@@ -356,24 +358,23 @@ class Linearizer(Kernel):
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
# end the late reduce loop
end_loop(end_local_idxs)
self.load_cache.clear()
# load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
# run late AST
val = self.ast_parse(self.ast, acc, None, loaded_buffers)
val = self.ast_parse(self.ast, acc, None, loaded_buffers, loop_ctx=global_loop_ctx)
# store
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
# end the global (and maybe local) loop
# end the if statement if we used it
if if_gate: self.uop(UOps.END, None, (if_gate,))
end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
# (recursively) remove childless uops
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
while 1:
has_child: Set[UOp] = set()
for ru in self.uops:
@@ -384,7 +385,28 @@ class Linearizer(Kernel):
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu
def get_recursive_deps(x:UOp) -> List[UOp]:
# Return x plus every uop in self.uops that consumes it, directly or through other consumers.
deps = set([x])
ssize = 0
# rescan self.uops until a full pass adds nothing new (the set size stops changing)
while ssize != len(deps):
ssize = len(deps)
for u in self.uops:
# inputs that are themselves PHI uops are ignored here, so a PHI may join the set
# but uops that reach the set only through a PHI input are not pulled in
# (the comprehension's x is local to it and does not rebind the parameter x)
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
deps.add(u)
# ordered by each uop's num field, so the last element has the highest num
return sorted(list(deps), key=lambda x: x.num)
# add END of loops after the last thing that (recursively) depends on them
for u in self.uops:
if u.uop == UOps.LOOP:
last_phi = self.uops.index(get_recursive_deps(u)[-1])
at_end = self.uops[last_phi+1:]
self.uops = self.uops[:last_phi+1]
self.uop(UOps.END, None, (u,), cachable=False)
self.uops += at_end
# maybe graph the uops
if DEBUG >= 5:
for u in self.uops: print(u)
if getenv("GRAPHUOPS"):
from tinygrad.graph import graph_uops
graph_uops(self.uops)
@@ -415,7 +437,7 @@ class Linearizer(Kernel):
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
if DEBUG >= 5: print(self.uops[-1])
#if DEBUG >= 5: print(self.uops[-1])
if cachable: self.saved_exprs[key] = self.uops[-1]
return self.uops[-1]

View File

@@ -130,8 +130,10 @@ class dtypes:
# NOTE: these are internal dtypes, should probably check for that
_int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
_half16: Final[DType] = DType(0, 2*16, "half16", None, 16)
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
_float8: Final[DType] = DType(4, 4*8, "float8", None, 8)
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
# NOTE: these are image dtypes

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast
import math
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
@@ -46,6 +46,8 @@ class CStyleLanguage(NamedTuple):
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert self.float4 is not None, "cast is not supported on this platform"
if var_dtype == dtypes._half16: return f"{{{','.join(f'(half){x}' for x in x)}}}"
if var_dtype == dtypes._float8: return f"{{{','.join(x)}}}"
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})"
@@ -141,21 +143,19 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
kk("}")
elif uop == UOps.WMMA:
if args[0] == "METAL":
assert dtype == dtypes._float2, "output dtype of METAL TC is _float2"
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
output = ssa(u, 'wmma')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};")
kk("{ simdgroup_float8x8 a,b,c;")
kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
kk("simdgroup_multiply_accumulate(c, a, b, c);")
kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}")
kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
elif args[0] == "HIP":
kk("{")
kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};")
kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};")
kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};")
kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];")
kk("}")
assert dtype == dtypes._float8, "output dtype of HIP TC is _float8"
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
else:
raise NotImplementedError(f"WMMA not implemented for {args}")
elif uop == UOps.ALU:
@@ -205,7 +205,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
bufs.append(args)
r[u] = args[0]
elif uop == UOps.GEP:
r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
if cast(DType, vin[0].dtype).sz > 4:
r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP
else:
r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
else:
raise RuntimeError(f"failed to render {uop}")