mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
proper wmma (#2245)
* proper wmma * hip cast * bugfixes * bugfix * that bug is fixed --------- Co-authored-by: George Hotz <george@tinygrad.org>
This commit is contained in:
@@ -27,13 +27,19 @@ repos:
|
|||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
- id: tests
|
- id: tests
|
||||||
name: subset of (CPU) tests
|
name: subset of (CPU) tests
|
||||||
entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py test/external/test_example.py
|
entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
|
||||||
|
language: system
|
||||||
|
always_run: true
|
||||||
|
pass_filenames: false
|
||||||
|
- id: example
|
||||||
|
name: multi device tests
|
||||||
|
entry: python3 test/external/test_example.py
|
||||||
language: system
|
language: system
|
||||||
always_run: true
|
always_run: true
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
- id: pylint
|
- id: pylint
|
||||||
name: pylint
|
name: pylint
|
||||||
entry: python -m pylint tinygrad/
|
entry: python3 -m pylint tinygrad/
|
||||||
language: system
|
language: system
|
||||||
always_run: true
|
always_run: true
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
|||||||
@@ -67,13 +67,13 @@ with open("/tmp/cc2.elf", "wb") as f:
|
|||||||
f.write(asm)
|
f.write(asm)
|
||||||
|
|
||||||
print(colored("creating CLProgram", "green"))
|
print(colored("creating CLProgram", "green"))
|
||||||
prg = CLProgram("code", asm, binary=True)
|
prg = CLProgram("code", asm)
|
||||||
|
|
||||||
print(colored("running program", "green"))
|
print(colored("running program", "green"))
|
||||||
G = 512
|
G = 512
|
||||||
FLOPS *= 100000*G*G # loop * global_size
|
FLOPS *= 100000*G*G # loop * global_size
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
tm = prg([G//256, G], [256, 1], buf, wait=True)
|
tm = prg(buf, global_size=[G//256, G, 1], local_size=[256, 1, 1], wait=True)
|
||||||
print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")
|
print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")
|
||||||
|
|
||||||
print(colored("transferring buffer", "green"))
|
print(colored("transferring buffer", "green"))
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
|||||||
lin = Linearizer(ast)
|
lin = Linearizer(ast)
|
||||||
assert fuzz_linearizer(lin) != "PASS"
|
assert fuzz_linearizer(lin) != "PASS"
|
||||||
|
|
||||||
@unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU", "CLANG"], "fails on these backends")
|
@unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU"], "fails on these backends")
|
||||||
def test_failure_3(self):
|
def test_failure_3(self):
|
||||||
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1))
|
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1))
|
||||||
lin = Linearizer(ast)
|
lin = Linearizer(ast)
|
||||||
|
|||||||
@@ -209,15 +209,11 @@ class Linearizer(Kernel):
|
|||||||
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}
|
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}
|
||||||
self.loop_uops.update(new_loops)
|
self.loop_uops.update(new_loops)
|
||||||
return tuple(new_loops.values())
|
return tuple(new_loops.values())
|
||||||
def end_loop(xx:List[Variable]):
|
|
||||||
for x in xx[::-1]:
|
|
||||||
if not isinstance(x, NumNode) and x.expr is not None:
|
|
||||||
loop_uop = self.loop_uops[x.expr]
|
|
||||||
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
|
|
||||||
|
|
||||||
# set global/local size
|
# set global/local size
|
||||||
self.global_size: Optional[List[int]] = None
|
self.global_size: Optional[List[int]] = None
|
||||||
self.local_size: Optional[List[int]] = None
|
self.local_size: Optional[List[int]] = None
|
||||||
|
global_loop_ctx: Tuple[UOp, ...] = tuple()
|
||||||
if self.dont_use_locals:
|
if self.dont_use_locals:
|
||||||
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
|
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
|
||||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
||||||
@@ -228,7 +224,7 @@ class Linearizer(Kernel):
|
|||||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
||||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
|
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
|
||||||
else:
|
else:
|
||||||
render_loop(loop_global_idxs+loop_local_idxs)
|
global_loop_ctx = render_loop(loop_global_idxs+loop_local_idxs)
|
||||||
|
|
||||||
# parse AST
|
# parse AST
|
||||||
loaded_buffers = {}
|
loaded_buffers = {}
|
||||||
@@ -296,8 +292,16 @@ class Linearizer(Kernel):
|
|||||||
for y in range(by):
|
for y in range(by):
|
||||||
for x in range(bx):
|
for x in range(bx):
|
||||||
for j in range(acc_reds):
|
for j in range(acc_reds):
|
||||||
# TODO: make this a proper op with PHI node
|
op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]]
|
||||||
self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
|
if self.opts.device != "HIP":
|
||||||
|
ops = tuple(op1+op2+op3)
|
||||||
|
else:
|
||||||
|
ops = (self.uop(UOps.CAST, dtypes._half16, tuple(op1)),
|
||||||
|
self.uop(UOps.CAST, dtypes._half16, tuple(op2)),
|
||||||
|
self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
|
||||||
|
ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
|
||||||
|
for z in range(cast(DType, ret.dtype).sz):
|
||||||
|
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + global_loop_ctx + loop_ctx)
|
||||||
i += wmma_sz[2]
|
i += wmma_sz[2]
|
||||||
else:
|
else:
|
||||||
if locals_to_store:
|
if locals_to_store:
|
||||||
@@ -309,10 +313,9 @@ class Linearizer(Kernel):
|
|||||||
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
|
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
|
||||||
|
|
||||||
# run early AST (with reduce)
|
# run early AST (with reduce)
|
||||||
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
|
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=global_loop_ctx + loop_ctx)
|
||||||
|
|
||||||
# end the reduce loop
|
# end the reduce loop
|
||||||
end_loop(reduce_idxs)
|
|
||||||
self.load_cache.clear()
|
self.load_cache.clear()
|
||||||
|
|
||||||
# end the local loop, do the local reduce
|
# end the local loop, do the local reduce
|
||||||
@@ -320,7 +323,6 @@ class Linearizer(Kernel):
|
|||||||
fake_global_idxs = [x*0 for x in global_idxs]
|
fake_global_idxs = [x*0 for x in global_idxs]
|
||||||
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
|
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
|
||||||
self.uop(UOps.BARRIER, None, (), cachable=False)
|
self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||||
end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
|
|
||||||
if self.opts.has_local:
|
if self.opts.has_local:
|
||||||
fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
|
fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
|
||||||
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
|
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
|
||||||
@@ -356,24 +358,23 @@ class Linearizer(Kernel):
|
|||||||
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
|
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
|
||||||
|
|
||||||
# end the late reduce loop
|
# end the late reduce loop
|
||||||
end_loop(end_local_idxs)
|
|
||||||
self.load_cache.clear()
|
self.load_cache.clear()
|
||||||
|
|
||||||
# load latebufs
|
# load latebufs
|
||||||
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
|
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
|
||||||
|
|
||||||
# run late AST
|
# run late AST
|
||||||
val = self.ast_parse(self.ast, acc, None, loaded_buffers)
|
val = self.ast_parse(self.ast, acc, None, loaded_buffers, loop_ctx=global_loop_ctx)
|
||||||
|
|
||||||
# store
|
# store
|
||||||
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
|
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
|
||||||
|
|
||||||
# end the global (and maybe local) loop
|
# end the if statement if we used it
|
||||||
if if_gate: self.uop(UOps.END, None, (if_gate,))
|
if if_gate: self.uop(UOps.END, None, (if_gate,))
|
||||||
end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
|
|
||||||
|
|
||||||
# (recursively) remove childless uops
|
# (recursively) remove childless uops
|
||||||
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
|
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
|
||||||
|
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
|
||||||
while 1:
|
while 1:
|
||||||
has_child: Set[UOp] = set()
|
has_child: Set[UOp] = set()
|
||||||
for ru in self.uops:
|
for ru in self.uops:
|
||||||
@@ -384,7 +385,28 @@ class Linearizer(Kernel):
|
|||||||
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
|
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
|
||||||
self.uops = nu
|
self.uops = nu
|
||||||
|
|
||||||
|
def get_recursive_deps(x:UOp) -> List[UOp]:
|
||||||
|
deps = set([x])
|
||||||
|
ssize = 0
|
||||||
|
while ssize != len(deps):
|
||||||
|
ssize = len(deps)
|
||||||
|
for u in self.uops:
|
||||||
|
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
|
||||||
|
deps.add(u)
|
||||||
|
return sorted(list(deps), key=lambda x: x.num)
|
||||||
|
|
||||||
|
# add END of loops after the last thing that (recursively) depends on them
|
||||||
|
for u in self.uops:
|
||||||
|
if u.uop == UOps.LOOP:
|
||||||
|
last_phi = self.uops.index(get_recursive_deps(u)[-1])
|
||||||
|
at_end = self.uops[last_phi+1:]
|
||||||
|
self.uops = self.uops[:last_phi+1]
|
||||||
|
self.uop(UOps.END, None, (u,), cachable=False)
|
||||||
|
self.uops += at_end
|
||||||
|
|
||||||
# maybe graph the uops
|
# maybe graph the uops
|
||||||
|
if DEBUG >= 5:
|
||||||
|
for u in self.uops: print(u)
|
||||||
if getenv("GRAPHUOPS"):
|
if getenv("GRAPHUOPS"):
|
||||||
from tinygrad.graph import graph_uops
|
from tinygrad.graph import graph_uops
|
||||||
graph_uops(self.uops)
|
graph_uops(self.uops)
|
||||||
@@ -415,7 +437,7 @@ class Linearizer(Kernel):
|
|||||||
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
|
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
|
||||||
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
|
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
|
||||||
self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
|
self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
|
||||||
if DEBUG >= 5: print(self.uops[-1])
|
#if DEBUG >= 5: print(self.uops[-1])
|
||||||
if cachable: self.saved_exprs[key] = self.uops[-1]
|
if cachable: self.saved_exprs[key] = self.uops[-1]
|
||||||
return self.uops[-1]
|
return self.uops[-1]
|
||||||
|
|
||||||
|
|||||||
@@ -130,8 +130,10 @@ class dtypes:
|
|||||||
# NOTE: these are internal dtypes, should probably check for that
|
# NOTE: these are internal dtypes, should probably check for that
|
||||||
_int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
|
_int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
|
||||||
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
|
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
|
||||||
|
_half16: Final[DType] = DType(0, 2*16, "half16", None, 16)
|
||||||
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
|
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
|
||||||
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
|
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
|
||||||
|
_float8: Final[DType] = DType(4, 4*8, "float8", None, 8)
|
||||||
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
|
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
|
||||||
|
|
||||||
# NOTE: these are image dtypes
|
# NOTE: these are image dtypes
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict
|
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast
|
||||||
import math
|
import math
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from tinygrad.codegen.linearizer import UOps, UOp
|
from tinygrad.codegen.linearizer import UOps, UOp
|
||||||
@@ -46,6 +46,8 @@ class CStyleLanguage(NamedTuple):
|
|||||||
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
|
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
|
||||||
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
|
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
|
||||||
assert self.float4 is not None, "cast is not supported on this platform"
|
assert self.float4 is not None, "cast is not supported on this platform"
|
||||||
|
if var_dtype == dtypes._half16: return f"{{{','.join(f'(half){x}' for x in x)}}}"
|
||||||
|
if var_dtype == dtypes._float8: return f"{{{','.join(x)}}}"
|
||||||
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
|
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
|
||||||
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
|
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
|
||||||
if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})"
|
if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})"
|
||||||
@@ -141,21 +143,19 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||||||
kk("}")
|
kk("}")
|
||||||
elif uop == UOps.WMMA:
|
elif uop == UOps.WMMA:
|
||||||
if args[0] == "METAL":
|
if args[0] == "METAL":
|
||||||
|
assert dtype == dtypes._float2, "output dtype of METAL TC is _float2"
|
||||||
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
|
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
|
||||||
|
output = ssa(u, 'wmma')
|
||||||
|
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};")
|
||||||
kk("{ simdgroup_float8x8 a,b,c;")
|
kk("{ simdgroup_float8x8 a,b,c;")
|
||||||
kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
|
kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
|
||||||
kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
|
kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
|
||||||
kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
|
kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
|
||||||
kk("simdgroup_multiply_accumulate(c, a, b, c);")
|
kk("simdgroup_multiply_accumulate(c, a, b, c);")
|
||||||
kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}")
|
kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
|
||||||
elif args[0] == "HIP":
|
elif args[0] == "HIP":
|
||||||
kk("{")
|
assert dtype == dtypes._float8, "output dtype of HIP TC is _float8"
|
||||||
kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};")
|
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||||
kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};")
|
|
||||||
kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};")
|
|
||||||
kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
|
|
||||||
for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];")
|
|
||||||
kk("}")
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"WMMA not implemented for {args}")
|
raise NotImplementedError(f"WMMA not implemented for {args}")
|
||||||
elif uop == UOps.ALU:
|
elif uop == UOps.ALU:
|
||||||
@@ -205,7 +205,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||||||
bufs.append(args)
|
bufs.append(args)
|
||||||
r[u] = args[0]
|
r[u] = args[0]
|
||||||
elif uop == UOps.GEP:
|
elif uop == UOps.GEP:
|
||||||
r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
|
if cast(DType, vin[0].dtype).sz > 4:
|
||||||
|
r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP
|
||||||
|
else:
|
||||||
|
r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"failed to render {uop}")
|
raise RuntimeError(f"failed to render {uop}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user