mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 06:34:03 -05:00

proper wmma (#2245)

* proper wmma
* hip cast
* bugfixes
* bugfix
* that bug is fixed

---------

Co-authored-by: George Hotz <george@tinygrad.org>

@@ -27,13 +27,19 @@ repos:
       pass_filenames: false
   - id: tests
     name: subset of (CPU) tests
-    entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py test/external/test_example.py
+    entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
     language: system
     always_run: true
     pass_filenames: false
+  - id: example
+    name: multi device tests
+    entry: python3 test/external/test_example.py
+    language: system
+    always_run: true
+    pass_filenames: false
   - id: pylint
     name: pylint
-    entry: python -m pylint tinygrad/
+    entry: python3 -m pylint tinygrad/
     language: system
     always_run: true
     pass_filenames: false

@@ -67,13 +67,13 @@ with open("/tmp/cc2.elf", "wb") as f:
   f.write(asm)
 
 print(colored("creating CLProgram", "green"))
-prg = CLProgram("code", asm, binary=True)
+prg = CLProgram("code", asm)
 
 print(colored("running program", "green"))
 G = 512
 FLOPS *= 100000*G*G   # loop * global_size
 for i in range(3):
-  tm = prg([G//256, G], [256, 1], buf, wait=True)
+  tm = prg(buf, global_size=[G//256, G, 1], local_size=[256, 1, 1], wait=True)
   print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")
 
 print(colored("transferring buffer", "green"))

@@ -24,7 +24,7 @@ class TestLinearizerFailures(unittest.TestCase):
     lin = Linearizer(ast)
     assert fuzz_linearizer(lin) != "PASS"
 
-  @unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU", "CLANG"], "fails on these backends")
+  @unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU"], "fails on these backends")
   def test_failure_3(self):
     ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1))
     lin = Linearizer(ast)

@@ -209,15 +209,11 @@ class Linearizer(Kernel):
         self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}
       self.loop_uops.update(new_loops)
       return tuple(new_loops.values())
-    def end_loop(xx:List[Variable]):
-      for x in xx[::-1]:
-        if not isinstance(x, NumNode) and x.expr is not None:
-          loop_uop = self.loop_uops[x.expr]
-          if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
 
     # set global/local size
     self.global_size: Optional[List[int]] = None
     self.local_size: Optional[List[int]] = None
+    global_loop_ctx: Tuple[UOp, ...] = tuple()
     if self.dont_use_locals:
       self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
@@ -228,7 +224,7 @@ class Linearizer(Kernel):
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
     else:
-      render_loop(loop_global_idxs+loop_local_idxs)
+      global_loop_ctx = render_loop(loop_global_idxs+loop_local_idxs)
 
     # parse AST
     loaded_buffers = {}
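
Two related things happen in these hunks: render_loop now hands back the LOOP uops it opens, and the explicit end_loop helper disappears because END placement moves to a dependency pass after linearization (the @@ -384 hunk below). The returned tuple, kept as global_loop_ctx, exists so accumulator PHIs can name the loops they iterate over. A runnable toy of the pattern, with stand-in classes that are not the tinygrad ones:

    from typing import Tuple

    class UOp:  # toy stand-in, not the tinygrad class
      def __init__(self, uop: str, vin: Tuple["UOp", ...] = ()):
        self.uop, self.vin = uop, vin

    def render_loop(n: int) -> Tuple[UOp, ...]:
      # open n loops and hand the LOOP uops back to the caller
      return tuple(UOp("LOOP") for _ in range(n))

    loop_ctx = render_loop(2)
    acc = UOp("PHI", (UOp("CONST"), UOp("ALU")) + loop_ctx)  # the PHI lists its loops in vin
    assert all(l in acc.vin for l in loop_ctx)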

@@ -296,8 +292,16 @@ class Linearizer(Kernel):
           for y in range(by):
             for x in range(bx):
               for j in range(acc_reds):
-                # TODO: make this a proper op with PHI node
-                self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
+                op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]]
+                if self.opts.device != "HIP":
+                  ops = tuple(op1+op2+op3)
+                else:
+                  ops = (self.uop(UOps.CAST, dtypes._half16, tuple(op1)),
+                         self.uop(UOps.CAST, dtypes._half16, tuple(op2)),
+                         self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
+                ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
+                for z in range(cast(DType, ret.dtype).sz):
+                  acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + global_loop_ctx + loop_ctx)
                 i += wmma_sz[2]
         else:
           if locals_to_store:
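
This hunk is the commit's namesake. WMMA used to be a fire-and-forget uop that mutated its accumulator inputs in place (hence the old TODO); now it is a value-producing op: the result is a _float2 or _float8 vector, GEP splits out the scalar lanes, and PHI ties each lane back to its accumulator across the surrounding loops. On HIP the operands are first packed into half16/float8 vectors with CAST. A runnable toy of the GEP/PHI wiring, using made-up stand-in classes:

    from typing import Tuple

    class UOp:  # toy stand-in, not the tinygrad class
      def __init__(self, uop: str, vin: Tuple["UOp", ...] = (), arg=None):
        self.uop, self.vin, self.arg = uop, vin, arg

    def wmma_step(a, b, acc, loop_ctx, sz=2):
      # WMMA now returns one vector value instead of writing acc in place
      ret = UOp("WMMA", tuple(a) + tuple(b) + tuple(acc))
      for z in range(sz):
        lane = UOp("GEP", (ret,), z)                    # extract scalar lane z
        acc[z] = UOp("PHI", (acc[z], lane) + loop_ctx)  # carry it across the loop

    acc = [UOp("CONST"), UOp("CONST")]
    wmma_step((UOp("LOAD"),), (UOp("LOAD"),), acc, (UOp("LOOP"),))
    assert all(u.uop == "PHI" for u in acc)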

@@ -309,10 +313,9 @@ class Linearizer(Kernel):
       loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
 
       # run early AST (with reduce)
-      self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
+      self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=global_loop_ctx + loop_ctx)
 
       # end the reduce loop
-      end_loop(reduce_idxs)
       self.load_cache.clear()
 
       # end the local loop, do the local reduce
@@ -320,7 +323,6 @@ class Linearizer(Kernel):
         fake_global_idxs = [x*0 for x in global_idxs]
         self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc)  # store accumulators
         self.uop(UOps.BARRIER, None, (), cachable=False)
-        end_loop(loop_local_idxs)  # TODO: this is ending too much, should only end what's in the if?
         if self.opts.has_local:
           fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
           fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]

@@ -356,24 +358,23 @@ class Linearizer(Kernel):
       self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)  # type: ignore
 
       # end the late reduce loop
-      end_loop(end_local_idxs)
       self.load_cache.clear()
 
     # load latebufs
     loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
 
     # run late AST
-    val = self.ast_parse(self.ast, acc, None, loaded_buffers)
+    val = self.ast_parse(self.ast, acc, None, loaded_buffers, loop_ctx=global_loop_ctx)
 
     # store
     self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
 
-    # end the global (and maybe local) loop
+    # end the if statement if we used it
     if if_gate: self.uop(UOps.END, None, (if_gate,))
-    end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
 
     # (recursively) remove childless uops
-    UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
+    # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
+    UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
     while 1:
       has_child: Set[UOp] = set()
       for ru in self.uops:
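
Dropping UOps.WMMA from the side-effect set follows directly from the earlier hunk: a WMMA now produces a value, so ordinary dead-code elimination keeps it alive exactly when something consumes its result through GEP. The childless-uop sweep referenced here repeatedly removes anything with no consumers; a small runnable sketch of that fixpoint over a toy node list (the types are stand-ins):

    from typing import List, Set, Tuple

    Node = Tuple[str, Tuple[int, ...]]  # (kind, ids of inputs); a node's id is its index

    def eliminate(nodes: List[Node], side_effects: Set[str]) -> Set[int]:
      live = set(range(len(nodes)))
      while True:  # iterate until nothing else can be dropped
        has_child = {i for j in live for i in nodes[j][1]}
        nu = {j for j in live if j in has_child or nodes[j][0] in side_effects}
        if nu == live: return live
        live = nu

    # WMMA(0) feeds GEP(1) feeds STORE(2); CONST(3) is childless and gets swept
    nodes = [("WMMA", ()), ("GEP", (0,)), ("STORE", (1,)), ("CONST", ())]
    assert eliminate(nodes, {"STORE"}) == {0, 1, 2}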

@@ -384,7 +385,28 @@ class Linearizer(Kernel):
       if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
       self.uops = nu
 
+    def get_recursive_deps(x:UOp) -> List[UOp]:
+      deps = set([x])
+      ssize = 0
+      while ssize != len(deps):
+        ssize = len(deps)
+        for u in self.uops:
+          if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
+            deps.add(u)
+      return sorted(list(deps), key=lambda x: x.num)
+
+    # add END of loops after the last thing that (recursively) depends on them
+    for u in self.uops:
+      if u.uop == UOps.LOOP:
+        last_phi = self.uops.index(get_recursive_deps(u)[-1])
+        at_end = self.uops[last_phi+1:]
+        self.uops = self.uops[:last_phi+1]
+        self.uop(UOps.END, None, (u,), cachable=False)
+        self.uops += at_end
+
     # maybe graph the uops
     if DEBUG >= 5:
       for u in self.uops: print(u)
     if getenv("GRAPHUOPS"):
       from tinygrad.graph import graph_uops
       graph_uops(self.uops)
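
This is the piece that replaces end_loop: for each LOOP, find every uop that transitively depends on it while ignoring PHI back-edges (otherwise the accumulator cycle would drag the whole tail of the kernel into the set), then splice the END right after the last of them. get_recursive_deps is a fixpoint walk; here is the same idea as a runnable toy over (id, inputs) pairs:

    from typing import List, Set, Tuple

    Node = Tuple[int, Tuple[int, ...]]  # (id, ids of non-PHI inputs)

    def recursive_deps(nodes: List[Node], root: int) -> List[int]:
      deps: Set[int] = {root}
      size = 0
      while size != len(deps):  # grow the set to a fixpoint
        size = len(deps)
        for nid, vin in nodes:
          if deps.intersection(vin): deps.add(nid)
      return sorted(deps)

    # LOOP(0) is read by ALU(1), which feeds STORE(2); node 3 is unrelated,
    # so the END for loop 0 would be spliced in right after node 2
    assert recursive_deps([(0, ()), (1, (0,)), (2, (1,)), (3, ())], 0) == [0, 1, 2]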

@@ -415,7 +437,7 @@ class Linearizer(Kernel):
       if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
     if cachable and key in self.saved_exprs: return self.saved_exprs[key]
     self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
-    if DEBUG >= 5: print(self.uops[-1])
+    #if DEBUG >= 5: print(self.uops[-1])
     if cachable: self.saved_exprs[key] = self.uops[-1]
     return self.uops[-1]
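
The per-uop debug print goes away because the DEBUG >= 5 path above now dumps the whole list once after END insertion, when positions are final. Around it you can see the uop builder's two fast paths: algebraic folding (x/1.0 returns x) and expression dedup through saved_exprs. A runnable toy of that dedup pattern, with made-up names:

    from typing import Any, Dict, Tuple

    saved_exprs: Dict[Tuple, Any] = {}
    uops: list = []

    def uop(op: str, arg: Any = None, cachable: bool = True):
      key = (op, arg)
      if cachable and key in saved_exprs: return saved_exprs[key]  # reuse identical expr
      uops.append((op, arg, len(uops)))  # position doubles as the uop's num
      if cachable: saved_exprs[key] = uops[-1]
      return uops[-1]

    a, b = uop("CONST", 1.0), uop("CONST", 1.0)
    assert a is b and len(uops) == 1  # the second CONST was deduped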

@@ -130,8 +130,10 @@ class dtypes:
   # NOTE: these are internal dtypes, should probably check for that
   _int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
+  _half16: Final[DType] = DType(0, 2*16, "half16", None, 16)
   _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
   _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
+  _float8: Final[DType] = DType(4, 4*8, "float8", None, 8)
   _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
 
   # NOTE: these are image dtypes
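
The two new entries are the WMMA fragment shapes: half16 for the 16x16x16 input fragments and float8 for the per-lane accumulator. Reading the constructor calls against their neighbors, the positional fields appear to be (priority, itemsize in bytes, name, numpy type, vector width), so each itemsize is element size times lane count; a runnable toy restating that assumption:

    from typing import Final, NamedTuple, Optional

    class DType(NamedTuple):  # stand-in with the field layout the calls above suggest
      priority: int
      itemsize: int           # total bytes for the whole vector
      name: str
      np: Optional[type]
      sz: int = 1             # number of lanes

    _half16: Final[DType] = DType(0, 2*16, "half16", None, 16)
    _float8: Final[DType] = DType(4, 4*8, "float8", None, 8)
    assert _half16.itemsize == 2*_half16.sz and _float8.itemsize == 4*_float8.sz  # 32 bytes each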

@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict
+from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast
 import math
 from collections import defaultdict
 from tinygrad.codegen.linearizer import UOps, UOp

@@ -46,6 +46,8 @@ class CStyleLanguage(NamedTuple):
     if len(x) == 1: return f"({var_dtype.name})({x[0]})"
     assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
     assert self.float4 is not None, "cast is not supported on this platform"
+    if var_dtype == dtypes._half16: return f"{{{','.join(f'(half){x}' for x in x)}}}"
+    if var_dtype == dtypes._float8: return f"{{{','.join(x)}}}"
     if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
     if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
     if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})"
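
The new dtypes render as C brace initializers rather than constructor-style casts like float4(...). A quick runnable check of the strings these two lines would produce, with x standing in for a list of already-rendered element expressions:

    x = ["a", "b"]  # stands in for rendered element expressions
    half_style  = f"{{{','.join(f'(half){e}' for e in x)}}}"
    float_style = f"{{{','.join(x)}}}"
    assert half_style == "{(half)a,(half)b}" and float_style == "{a,b}"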

@@ -141,21 +143,19 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
       kk("}")
     elif uop == UOps.WMMA:
       if args[0] == "METAL":
+        assert dtype == dtypes._float2, "output dtype of METAL TC is _float2"
         # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
+        output = ssa(u, 'wmma')
+        kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};")
         kk("{ simdgroup_float8x8 a,b,c;")
         kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
         kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
         kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
         kk("simdgroup_multiply_accumulate(c, a, b, c);")
-        kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}")
+        kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
       elif args[0] == "HIP":
-        kk("{")
-        kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};")
-        kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};")
-        kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};")
-        kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
-        for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];")
-        kk("}")
+        assert dtype == dtypes._float8, "output dtype of HIP TC is _float8"
+        kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
       else:
         raise NotImplementedError(f"WMMA not implemented for {args}")
     elif uop == UOps.ALU:
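
Both backends now bind the WMMA result to a fresh ssa variable instead of writing back into the accumulator registers. On HIP the body collapses to one line because the linearizer's new CAST uops already packed the 16+16+8 scalars into half16/float8 values. Roughly the source each backend would emit for one uop, with illustrative variable names, shown as Python lists of lines:

    # illustrative output only; variable names are made up
    metal = [
      "float2 wmma0;",
      "{ simdgroup_float8x8 a,b,c;",
      "a.thread_elements()[0] = val0; a.thread_elements()[1] = val1;",
      "b.thread_elements()[0] = val2; b.thread_elements()[1] = val3;",
      "c.thread_elements()[0] = val4; c.thread_elements()[1] = val5;",
      "simdgroup_multiply_accumulate(c, a, b, c);",
      "wmma0.x = c.thread_elements()[0]; wmma0.y = c.thread_elements()[1]; }",
    ]
    hip = ["float8 wmma0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(cast0, cast1, cast2);"]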

@@ -205,7 +205,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
       bufs.append(args)
       r[u] = args[0]
     elif uop == UOps.GEP:
-      r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
+      if cast(DType, vin[0].dtype).sz > 4:
+        r[u] = f"({r[vin[0]]})[{args}]"  # this is correct for HIP
+      else:
+        r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
     else:
       raise RuntimeError(f"failed to render {uop}")
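
GEP needed this because the .xyzw swizzle only covers four lanes and the new float8 result is eight wide; wider vectors fall back to array subscripting. A runnable restatement of the branch:

    def render_gep(src: str, idx: int, sz: int) -> str:
      # swizzle for vectors of 4 lanes or fewer, subscript for the wide WMMA fragments
      return f"({src})[{idx}]" if sz > 4 else f"({src}).{'xyzw'[idx]}"

    assert render_gep("val", 2, 4) == "(val).z"
    assert render_gep("wmma0", 6, 8) == "(wmma0)[6]"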