Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-22 21:38:10 -05:00)
less wmma args (#11385)
* less wmma args
* scalar
* ops_python
* mypy
* lint
* dedup
* helper wmma_args

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
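The heart of the change: each renderer used to rescan the uops with `dedup([uop.arg for uop in uops if uop.op is Ops.WMMA])` and index raw tuple positions (`arg[0]`, `arg[2]`, `arg[3]`, ...), while the new shared `wmma_args` helper yields deduplicated 8-tuples that renderers unpack by name, reading the input and output scalar dtypes off the UOp itself (`uop.src[0].dtype.scalar()` and `uop.dtype.scalar()`) rather than trusting positions 2 and 3 of `arg`. Below is a minimal self-contained sketch of that pattern; `ToyUOp`, `ToyDType`, and the string ops are hypothetical stand-ins, not tinygrad's real classes:

```python
from typing import NamedTuple

class ToyDType(NamedTuple):
    name: str
    def scalar(self) -> "ToyDType": return self  # real DType.scalar() drops the vector count

class ToyUOp(NamedTuple):
    op: str          # the real code checks `uop.op is Ops.WMMA`
    dtype: ToyDType  # the WMMA uop's own dtype carries the output type
    src: tuple       # src[0] carries the input type
    arg: tuple       # (name, dims, _, _, device, threads, upcast_axes, reduce_axes)

def dedup(it):  # order-preserving dedup, like tinygrad.helpers.dedup
    seen = set()
    return [x for x in it if not (x in seen or seen.add(x))]

# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
def wmma_args(uops: list[ToyUOp]):
    return dedup((u.arg[0], u.arg[1], u.src[0].dtype.scalar(), u.dtype.scalar(), *u.arg[4:8])
                 for u in uops if u.op == "WMMA")

half, float32 = ToyDType("half"), ToyDType("float")
a = ToyUOp("LOAD", half, (), ())
w = ToyUOp("WMMA", float32, (a, a, a), ("WMMA_8_8_8", (8, 8, 8), None, None, "METAL", 32, (), ()))
# identical WMMA uops collapse to one arg tuple, with dtypes rebuilt from the graph
assert wmma_args([w, w]) == [("WMMA_8_8_8", (8, 8, 8), half, float32, "METAL", 32, (), ())]
```

Deriving the dtypes from the graph itself keeps `arg` closer to a pure configuration tuple and removes one way for the two to drift apart.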
@@ -75,6 +75,10 @@ extra_pm = PatternMatcher([
 
 def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
 
+# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
+def wmma_args(uops:list[UOp]):
+  return dedup((uop.arg[0], uop.arg[1], uop.src[0].dtype.scalar(), uop.dtype.scalar(), *(uop.arg[4:8])) for uop in uops if uop.op is Ops.WMMA)
+
 class CStyleLanguage(Renderer):
   kernel_typedef: str = "void"
   buffer_prefix: str = ""
@@ -210,7 +214,7 @@ class ClangRenderer(CStyleLanguage):
   def _render_defines(self, uops) -> list[str]:
     prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
     # https://github.com/corsix/amx
-    for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+    for name, (N, M, _), dtype_in, _, _, _, _, _ in wmma_args(uops):
       prefix += [
         '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
         '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
@@ -270,9 +274,9 @@ class IntelRenderer(OpenCLRenderer):
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     prefix = []
-    for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
-      dt_in = ("ushort", "bf16") if arg[2] == dtypes.bfloat16 else (arg[2].name, "f16")
-      prefix.append(f"""{arg[3].name}8 __{arg[0]}({dt_in[0]}16 a, {dt_in[0]}16 b, {arg[3].name}8 c) {{
+    for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops):
+      dt_in = ("ushort", "bf16") if dtype_in == dtypes.bfloat16 else (dtype_in.name, "f16")
+      prefix.append(f"""{dtype_out.name}8 __{name}({dt_in[0]}16 a, {dt_in[0]}16 b, {dtype_out.name}8 c) {{
 return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix or None)
 
@@ -308,13 +312,13 @@ class MetalRenderer(CStyleLanguage):
   ]) + base_rewrite
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA])
-    for arg in wmma_args: prefix.append(
-      f"""{(dtype_out:=self.render_dtype(arg[3].vec(2)))} __{arg[0]}({(dtype_in:=self.render_dtype(arg[2].vec(2)))} a, {dtype_in} b, {dtype_out} c){{
-simdgroup_{self.render_dtype(arg[2])}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(arg[3])}8x8 mat_c;
+    prefix = ["#include <metal_stdlib>","using namespace metal;"]
+    for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): prefix.append(
+      f"""{(dstr_out:=self.render_dtype(dtype_out.vec(2)))} __{name}({(dstr_in:=self.render_dtype(dtype_in.vec(2)))} a, {dstr_in} b, {dstr_out} c){{
+simdgroup_{self.render_dtype(dtype_in)}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(dtype_out)}8x8 mat_c;
 mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
 mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
-simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dtype_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
+simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 _nms = "xyzwabcdefghijkl"
@@ -363,7 +367,7 @@ class CUDARenderer(CStyleLanguage):
 
     dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
     dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
-    for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+    for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
       wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
       n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
@@ -458,15 +462,15 @@ class AMDRenderer(CStyleLanguage):
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
 
-    for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+    for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
       if self.tensor_cores == tc.amd_cdna:
-        prefix.append(f"#define __{arg[0]} __builtin_amdgcn_mfma_f32_16x16x16{'f16' if arg[2] == dtypes.half else 'bf16_1k'}")
+        prefix.append(f"#define __{name} __builtin_amdgcn_mfma_f32_16x16x16{'f16' if dtype_in == dtypes.half else 'bf16_1k'}")
       # #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12
       elif self.tensor_cores == tc.amd_rdna4:
-        prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_{type_map[arg[3]]}_16x16x16_{type_map[arg[2]]}_w32_gfx12")
-      elif arg[3] == dtypes.float:
-        prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if arg[2] == dtypes.half else 'bf16'}_w32")
-      else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
+        prefix.append(f"#define __{name} __builtin_amdgcn_wmma_{type_map[dtype_out]}_16x16x16_{type_map[dtype_in]}_w32_gfx12")
+      elif dtype_out == dtypes.float:
+        prefix.append(f"#define __{name} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if dtype_in == dtypes.half else 'bf16'}_w32")
+      else: prefix.append(f"static inline __attribute__((device)) half8 __{name}"+"""(half16 a, half16 b, half8 c) {
 half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
 c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
 for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
@@ -207,9 +207,9 @@ class PTXRenderer(Renderer):
       elif u.op is Ops.DEFINE_GLOBAL: bufs.append((f"data{u.arg}", u.dtype))
       elif u.op is Ops.WMMA:
         # registers for packing/unpacking input and acc
-        self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.arg[2].itemsize)],
-                       [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.arg[2].itemsize)],
-                       [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.arg[3].itemsize)]]
+        self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)],
+                       [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
+                       [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
         r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
       prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
                        Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]),
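The `4 // itemsize` stride in the hunk above is the number of scalar values that pack into one 32-bit `b32` register (two halves, or one float). A small worked check of the register counts the `ssa` loop allocates; the itemsizes are the standard 2 bytes for half and 4 for float, and the fragment lengths are illustrative:

```python
# How many b32 registers back a WMMA fragment of n_frag_elems scalars.
CUDA_REG_BYTES = 4  # "4 => CUDA reg size in bytes", as the cstyle hunk above notes
for dtype_name, itemsize, n_frag_elems in [("half", 2, 8), ("float", 4, 4)]:
    stride = CUDA_REG_BYTES // itemsize           # scalars packed per b32 register
    n_regs = len(range(0, n_frag_elems, stride))  # iterations of the ssa loop above
    print(f"{dtype_name}: {stride} scalar(s)/reg -> {n_regs} b32 regs for {n_frag_elems} elems")
# half: 2 scalar(s)/reg -> 4 b32 regs for 8 elems
# float: 1 scalar(s)/reg -> 4 b32 regs for 4 elems
```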
@@ -123,24 +123,27 @@ class PythonProgram:
             out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
           return out
 
+        first_src_dtype = self.uops[idp[0]][1]
+        assert isinstance(first_src_dtype, DType) # mypy
+        dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
         # TODO: refactor these to a shared TensorCoreLayout in kernel.py
-        if arg[4] == "METAL":
+        if device == "METAL":
           # A (2 elements on 32 threads): row major
           def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
           # (i, j), C, D (2 elements on 32 threads): row major same as A/B
           def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
           ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-        elif arg[4] == "AMD" and arg[5] == 64:
+        elif device == "AMD" and threads == 64:
           def a_elem(x, k, row, goff): return x[k%4][goff + (k//4)*16 + row]
           def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
           def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
           ul[i] = wmma_helper(64, 16, 4, 4, 4, a_elem, b_elem, c_map)
-        elif arg[4] == "AMD" and len(inp[0]) == 8: # RDNA4
+        elif device == "AMD" and len(inp[0]) == 8: # RDNA4
           def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
           def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
           def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
           ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
-        elif arg[4] == "AMD":
+        elif device == "AMD":
          # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
           def a_elem(x, k, row, goff):
             assert x[k][goff+row] == x[k][goff+row+16], "warp elements not duplicated properly across lanes"
@@ -149,27 +152,27 @@ class PythonProgram:
           def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
           def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
           ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-        elif arg[4] == "CUDA":
+        elif device == "CUDA":
           # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
           def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
 
-          if arg[1] == (8,16,16):
+          if dims == (8,16,16):
             def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
             def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
             ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
 
-          elif arg[1] == (8,16,8) and arg[2] == dtypes.half:
+          elif dims == (8,16,8) and dtype_in == dtypes.half:
             def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
             def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
             ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
 
-          elif arg[1] == (8,16,8) and arg[2] == dtypes.float:
+          elif dims == (8,16,8) and dtype_in == dtypes.float:
             def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
             def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
             ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
 
           else: raise NotImplementedError(f"unimplemented tensor core {arg}")
-        elif arg[4] == "INTEL":
+        elif device == "INTEL":
           # A (16 elements on 8 threads)
           def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
           # B (16 elements on 8 threads)
@@ -177,7 +180,7 @@ class PythonProgram:
           # C, D (8 elements on 8 threads)
           def c_map(lane, elem): return (lane, elem)
           ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
-        elif arg[4] == "CPU":
+        elif device == "CPU":
           def elem(x, col, row, _): return x[col+row][0] # k is always 0
           def c_map(_, elem): return (elem%16, elem//16)
           ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
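A consistency check on the `wmma_helper` calls above: assuming the positional parameters are (threads, K, then per-thread element counts for A, B, and C/D), which is how each call lines up with its surrounding comments, every warp's C/D elements must tile exactly one output, i.e. threads * n_c == N * M. The (N, M) pairings below are read off each branch's comments and dims checks, not from a tinygrad API:

```python
# threads and n_c come from the wmma_helper calls; N, M are inferred per branch.
cases = {
    "METAL":          dict(threads=32, n_c=2,   N=8,  M=8),   # wmma_helper(32, 8, 2, 2, 2, ...)
    "AMD CDNA":       dict(threads=64, n_c=4,   N=16, M=16),  # wmma_helper(64, 16, 4, 4, 4, ...)
    "AMD RDNA4":      dict(threads=32, n_c=8,   N=16, M=16),  # wmma_helper(32, 16, 8, 8, 8, ...)
    "AMD RDNA3":      dict(threads=32, n_c=8,   N=16, M=16),  # wmma_helper(32, 16, 16, 16, 8, ...)
    "CUDA (8,16,16)": dict(threads=32, n_c=4,   N=8,  M=16),  # wmma_helper(32, 16, 8, 4, 4, ...)
    "CUDA (8,16,8)":  dict(threads=32, n_c=4,   N=8,  M=16),  # wmma_helper(32, 8, 4, 2, 4, ...)
    "INTEL":          dict(threads=8,  n_c=8,   N=8,  M=8),   # wmma_helper(8, 16, 16, 16, 8, ...)
    "CPU (AMX)":      dict(threads=1,  n_c=256, N=16, M=16),  # wmma_helper(1, 1, 16, 16, 256, ...)
}
for dev, c in cases.items():
    assert c["threads"] * c["n_c"] == c["N"] * c["M"], f"{dev} does not tile its output"
```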