wmma: refactor to remove wmma_func and create TC funcs as needed (#3945)

* wmma: refactor to remove wmma_func and create TC funcs as needed

* test_linearizer: disable bf16 CUDA during emulation testing

* cstyle: clean up creation of CUDA vec dtypes

* extra/gemm: add option to accumulate to bfloat16

* cleanups

* benchmark: add CUDA bfloat16 matmul

* more cleanups
Author: Francis Lam
Date: 2024-03-27 13:43:09 -07:00
Committed by: GitHub
Parent: 88b24df40a
Commit: 7c5729a3bd
9 changed files with 101 additions and 92 deletions
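
The gist of the refactor: the WMMA uop no longer carries an opaque, backend-specific function name (the old TensorCore.wmma_func); it now carries a description of the tensor core itself, and each renderer creates or #defines the wrapper it needs. A minimal sketch of the before/after arg with illustrative values (the real tuple is built in the linearizer hunk below):

from tinygrad.dtype import dtypes

# before: an opaque, backend-specific function name baked into the TensorCore table
old_wmma_arg = "__cuda_mma_m16n8k16_f16_f32"
# after: (name, dims (N, M, K), dtype_in, dtype_out, per-dim register counts, device),
# shown here for the CUDA half->float core (register counts = prod of thread_local_sizes)
new_wmma_arg = ("WMMA_8_16_16_half_float", (8, 16, 16), dtypes.half, dtypes.float, (8, 4, 4), "CUDA")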

View File

@@ -108,7 +108,9 @@ jobs:
 - name: Test speed vs torch
 run: CUDA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
 - name: Run Tensor Core GEMM
-run: CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
+run: |
+  CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
+  CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
 - name: Run LLaMA
 run: |
 CUDA=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -140,6 +142,7 @@ jobs:
 onnx_inference_speed.csv
 torch_speed.txt
 matmul.txt
+matmul_bfloat16.txt
 llama_unjitted.txt
 llama_jitted.txt
 llama_beam.txt

View File

@@ -1,8 +1,8 @@
 from tinygrad.helpers import getenv
 from tinygrad import dtypes, Tensor
-dtype_in = dtypes.half if getenv("HALF") else dtypes.float
-acc_dtype = dtypes.half if getenv("ACC_HALF") else None
+dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
+acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
 CNT = getenv("CNT", 8)
 BS = getenv("BS", 16)
@@ -12,6 +12,8 @@ HW = getenv("HW", 128)
 K = getenv("K", 3)
 PADDING = getenv("PADDING", 1)
 COMP = getenv("COMP", 0)
+ATOL = getenv("ATOL", 1e-4)
+RTOL = getenv("RTOL", 3e-2)
 FLOPS = BS*K*K*CIN*HW*HW*COUT*2
 def rand_input(): return Tensor.rand(BS, CIN, HW, HW, dtype=dtype_in).realize(), Tensor.rand(COUT, CIN, K, K, dtype=dtype_in).realize()
@@ -28,4 +30,4 @@ if __name__ == "__main__":
 torch_device = "cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")
 ta, tb = torch.from_numpy(a.numpy()).to(torch_device), torch.from_numpy(b.numpy()).to(torch_device)
 tc = torch.nn.functional.conv2d(ta, tb, padding=PADDING)
-np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=1e-4, rtol=3e-2)
+np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=ATOL, rtol=RTOL)

View File

@@ -1,10 +1,12 @@
 import numpy as np
 from tinygrad.helpers import getenv
 from tinygrad import dtypes, Tensor
-dtype_in = dtypes.half if getenv("HALF") else dtypes.float
-acc_dtype = dtypes.half if getenv("ACC_HALF") else None
+dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
+acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
 N = getenv("N", 4096)
 CNT = getenv("CNT", 10)
+ATOL = getenv("ATOL", 1e-4)
+RTOL = getenv("RTOL", 3e-2)
 if __name__ == "__main__":
 a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
@@ -14,4 +16,4 @@ if __name__ == "__main__":
 c = a.matmul(b, acc_dtype=acc_dtype).realize()
 comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
 nc = c.numpy()
-np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=3e-2)
+np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)

View File

@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.engine.jit import CacheCollector
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
-from tinygrad.helpers import prod, Context
+from tinygrad.helpers import prod, Context, getenv
 from tinygrad.dtype import DType, dtypes
 from tinygrad.codegen.uops import UOpGraph
@@ -186,6 +186,7 @@ class TestLinearizer(unittest.TestCase):
 if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
 self.skipTest("device doesn't have tensor cores")
 for tc in tensor_cores[Device[Device.DEFAULT].compiler.linearizer_opts.device]:
+if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
 a, b = Tensor.rand(tc.dims[1], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[0], dtype=tc.dtype_in)
 np_a, np_b = a.numpy(), b.numpy()
 r = a.matmul(b, acc_dtype=tc.dtype_out)

View File

@@ -30,16 +30,14 @@ class Opt:
 @dataclass(frozen=True)
 class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
-dims: List[int] # N, M, K
+dims: Tuple[int,int,int] # N, M, K
 dtype_in: DType # dtype for A and B
 dtype_out: DType # dtype for C and D
 threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
 thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
 thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
-wmma_func: str # name of wmma function to call
-def __str__(self): return f"tensor_core<{self.dims}, {self.dtype_in}, {self.dtype_out}>"
-def num_threads(self): return len(self.threads)
-def num_upcasts(self): return len(self.thread_local_aliases[0]) - self.num_threads()
+def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
 class TensorCoreOptions(NamedTuple):
 bufs: Tuple[int, int] # the local aliased buffers for A and B
@@ -51,18 +49,9 @@ class TensorCoreOptions(NamedTuple):
 elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False
 tensor_cores: Dict[str, List[TensorCore]] = {
-"METAL": [
-TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-],
-"HSA": [
-TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ]), # noqa: E501
-TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ]), # noqa: E501
-],
-"CUDA": [
-TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
-],
+"METAL": [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
+"HSA": [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
+"CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)]], # noqa: E501
 }
 class LocalBuffer(NamedTuple):

View File

@@ -291,10 +291,10 @@ class Linearizer(Kernel):
 if (tc:=self.tensor_core):
 min_alias_idx = min(self.local_alias.keys())
 replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
-for n in range(tc.num_threads()):
-buf_idxs[self.first_reduce-tc.num_threads()+n] = replace_input_idxs[n] # replace locals
+for n in range(len(tc.threads)):
+buf_idxs[self.first_reduce-len(tc.threads)+n] = replace_input_idxs[n] # replace locals
 for n in range(tc.num_upcasts()):
-buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[tc.num_threads()+n] # replace upcasts
+buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
 if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
 ll = self.global_load(i, buf_idxs)
 locals_to_store.append((localbuf_idx, buf_idxs, ll))
@@ -315,13 +315,13 @@ class Linearizer(Kernel):
 strides.append((0 if stride == 0 else next, sz))
 next *= 1 if stride == 0 else sz
 return strides
-upcasts = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]]
-for iter in [x[::-1] for x in [x for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]]:
+upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
+for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
 offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
 ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
 self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
-self.uops.add(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
-ret = self.uops.add(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func)
+self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
+ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(map(prod, tc.thread_local_sizes)), dev))
 for z in range(wmma_sz[2]):
 acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
 else:

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
 import functools, math, operator, itertools
 from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
 from collections import defaultdict
-from tinygrad.helpers import DEBUG, flatten
+from tinygrad.helpers import DEBUG, flatten, prod
 from tinygrad.dtype import dtypes, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode
@@ -385,8 +385,6 @@ class UOpGraph:
 assert u.vin[2].dtype is not None
 mem += u.vin[2].dtype.itemsize * mults
 elif u.uop is UOps.WMMA:
-if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
-elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
-elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
-else: raise NotImplementedError(f"not implemented wmma {u.arg=}")
+assert u.arg[1] is not None
+flops += 2 * prod(u.arg[1]) // 32 * mults
 return flops, mem
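
The generic count 2 * prod(dims) // 32 (32 being the warp/wavefront width) reproduces exactly the per-function constants it replaces; a quick illustrative check:

from tinygrad.helpers import prod

# old hard-coded flop counts per WMMA: METAL (8,8,8)=32, HSA (16,16,16)=256, CUDA (8,16,16)=128
for dims, old in [((8, 8, 8), 32), ((16, 16, 16), 256), ((8, 16, 16), 128)]:
  assert 2 * prod(dims) // 32 == old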

View File

@@ -164,7 +164,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
 assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
 bufs.append((args[1], (dtype,args[2])))
 r[u] = args[1]
-elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
+elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
 elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
 elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
 elif uop is UOps.GEP:
@@ -215,11 +215,11 @@ class MetalLanguage(CStyleLanguage):
 return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-prefix = ["#include <metal_stdlib>","using namespace metal;"]
-if any(uop.uop == UOps.WMMA for uop in uops): prefix.append("""template<typename T, typename S, typename U> U __metal_wmma(T m, T n, U o) {
-S a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x; b.thread_elements()[1] = n.y;
-c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
-return U(c.thread_elements()[0], c.thread_elements()[1]);\n}""")
+prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop == UOps.WMMA])
+for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
+simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
+b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
+return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
@@ -229,6 +229,11 @@ code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype
 UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
 UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
+_nms = "xyzwabcdefghijkl"
+def _make_cuda_dtype(base_type, name, cnt):
+vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
+return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
 class CUDALanguage(CStyleLanguage):
 kernel_prefix = "extern \"C\" __global__ "
 smem_prefix = "__shared__ "
@@ -241,17 +246,24 @@ class CUDALanguage(CStyleLanguage):
 type_map = {dtypes.bfloat16: "nv_bfloat16"}
 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
+# TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
+dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
 prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
 if any(uop.dtype == dtypes.half for uop in uops):
-prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };", "struct half8 { half x, y, z, w, a, b, c, d; };",
-"__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }",
-"__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }",
-"""__device__ float4 __cuda_mma_m16n8k16_f16_f32(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
-asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
-: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
-return c;}""",]
-if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <cuda_bf16.h>")
+prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
+if any(uop.dtype == dtypes.bfloat16 for uop in uops):
+prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
+# TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
+for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]):
+fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
+prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
+asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
+: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
+return c;}}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
 CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
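
To see what the loop above generates, here is an illustrative walk-through in Python for the half->float tensor core; the arg tuple below is hypothetical but shaped like the WMMA uop arg. It reproduces the signature and PTX mnemonic of the old hard-coded __cuda_mma_m16n8k16_f16_f32 wrapper, and the new bfloat16 core comes out as mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 from the same template:

from tinygrad.dtype import dtypes

dt_map = {dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16")}
arg = ("WMMA_8_16_16_half_float", (8, 16, 16), dtypes.half, dtypes.float, (8, 4, 4), "CUDA")  # hypothetical arg
fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
print(f"__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c)")     # float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c)
print(f"mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co}")  # mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
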
@@ -273,43 +285,31 @@ def _make_hip_code_for_op():
 return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
 def _make_hip_dtype(base_type, name, cnt):
-nms = "xyzwabcdefghijkl"[:cnt]
+elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
 return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
-f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}(" + ', '.join([f"{base_type} {x}" for x in nms]) + \
-") { return {" + ', '.join(nms) + "}; }"
+f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
 class HIPLanguage(CStyleLanguage):
-kernel_prefix = """
-#define half _Float16
-extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
-extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
-extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
-extern "C" {
-__attribute__((device)) __attribute__((const)) float __ocml_fmax_f32(float, float);
-__attribute__((device)) __attribute__((pure)) float __ocml_exp2_f32(float);
-__attribute__((device)) __attribute__((pure)) float __ocml_log2_f32(float);
-__attribute__((device)) float __ocml_sin_f32(float);
-__attribute__((device)) __attribute__((const)) float __ocml_sqrt_f32(float);
-__attribute__((device)) __attribute__((const)) double __ocml_fmax_f64(double, double);
-__attribute__((device)) __attribute__((pure)) double __ocml_exp2_f64(double);
-__attribute__((device)) __attribute__((pure)) double __ocml_log2_f64(double);
-__attribute__((device)) double __ocml_sin_f64(double);
-__attribute__((device)) __attribute__((const)) double __ocml_sqrt_f64(double);
-__attribute__((device)) __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
-__attribute__((device)) __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16);
-__attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
-__attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
-__attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
-}\n""" + '\n'.join([_make_hip_dtype(*x) for x in [
-("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16),
-("float", "float", 8)]]) + """
-static __attribute__((device)) half8 __hip_wmma_f16_f16(half16 a, half16 b, half8 c) {
-half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
-c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
-for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;
-}\nextern "C" __attribute__((global))"""
+kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
+extern "C" {
+__attribute__((device)) __attribute__((const)) float __ocml_fmax_f32(float, float);
+__attribute__((device)) __attribute__((pure)) float __ocml_exp2_f32(float);
+__attribute__((device)) __attribute__((pure)) float __ocml_log2_f32(float);
+__attribute__((device)) float __ocml_sin_f32(float);
+__attribute__((device)) __attribute__((const)) float __ocml_sqrt_f32(float);
+__attribute__((device)) __attribute__((const)) double __ocml_fmax_f64(double, double);
+__attribute__((device)) __attribute__((pure)) double __ocml_exp2_f64(double);
+__attribute__((device)) __attribute__((pure)) double __ocml_log2_f64(double);
+__attribute__((device)) double __ocml_sin_f64(double);
+__attribute__((device)) __attribute__((const)) double __ocml_sqrt_f64(double);
+__attribute__((device)) __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
+__attribute__((device)) __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16);
+__attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
+__attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
+__attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
+}\nextern "C" __attribute__((global))"""
 code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
 "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
 code_for_op = _make_hip_code_for_op()
@@ -321,8 +321,10 @@ class HIPLanguage(CStyleLanguage):
 type_map = {dtypes.bfloat16: "hip_bfloat16"}
 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-prefix = ["#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))",
-"typedef long unsigned int size_t;"]
+prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
+vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
+# TODO: add BF16 vec dts
 if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
 struct hip_bfloat16 {
 unsigned short data;
@@ -339,8 +341,19 @@ struct hip_bfloat16 {
 static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
 static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
 """)
-prefix.append('\n'.join(_make_hip_dtype(*x) for x in [("float", "float", 2), ("float", "float", 4),
-("signed int", "int", 4), ("signed int", "int", 2)]))
+if any(uop.dtype == dtypes.half for uop in uops):
+prefix.append("#define half _Float16")
+vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
+prefix += [_make_hip_dtype(*x) for x in vec_dts]
+for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
+else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
+half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
+c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
+for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 def get_kernel_modifier(self, uops:UOpGraph) -> str:
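
On HSA the same WMMA name is resolved per accumulator dtype: the float-accumulate core maps straight onto the compiler builtin via a #define, while the half-accumulate core keeps the small repacking wrapper. An illustrative sketch (names derived from the tensor core dims and dtypes; the full wrapper body is in the hunk above):

from tinygrad.dtype import dtypes

# hypothetical (name, dtype_out) pairs shaped like the two HSA tensor cores
for name, dtype_out in [("WMMA_16_16_16_half_float", dtypes.float), ("WMMA_16_16_16_half_half", dtypes.half)]:
  if dtype_out == dtypes.float: print(f"#define __{name} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
  else: print(f"static __attribute__((device)) half8 __{name}(half16 a, half16 b, half8 c) {{ /* repack via __builtin_amdgcn_wmma_f16_16x16x16_f16_w32 */ }}")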

View File

@@ -145,13 +145,14 @@ class PythonProgram:
 out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
 return out
-if arg.startswith('__metal_wmma'):
+# TODO: refactor these to a shared TensorCoreLayout in kernel.py
+if arg[5] == "METAL":
 # A (2 elements on 32 threads): row major
 def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
 # (i, j), C, D (2 elements on 32 threads): row major same as A/B
 def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
 ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-elif arg == '__builtin_amdgcn_wmma_f32_16x16x16_f16_w32' or arg == '__hip_wmma_f16_f16':
+elif arg[5] == "HSA":
 # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
 def a_elem(x, i, j, goff):
 assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
@@ -160,7 +161,7 @@ class PythonProgram:
 def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
 def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
 ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-elif arg == '__cuda_mma_m16n8k16_f16_f32':
+elif arg[5] == "CUDA":
 # A (8 elements on 32 threads)
 def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
 # B (4 elements on 32 threads)