wmma: add CUDA tensor core and fix test_speed_v_torch failure (#3544)

Francis Lam
2024-03-01 17:51:02 -08:00
committed by GitHub
parent b3cdc11a58
commit e17f1821a7
8 changed files with 50 additions and 25 deletions

View File

@@ -85,6 +85,8 @@ jobs:
       run: CUDA=1 python3 test/external/external_model_benchmark.py
     - name: Test speed vs torch
       run: CUDA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
+    - name: Run Tensor Core GEMM
+      run: CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
     - name: Run GPT2
       run: |
         CUDA=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
@@ -99,6 +101,7 @@ jobs:
         path: |
           onnx_inference_speed.csv
           torch_speed.txt
+          matmul.txt
           gpt2_unjitted.txt
           gpt2_jitted.txt
           gpt2_half_beam.txt
gpt2_half_beam.txt

View File

@@ -45,6 +45,7 @@ jobs:
       run: |
         DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
         DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
+        DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
     - name: Test dtype with Python emulator
       run: DEBUG=2 PYTHON=1 python3 test/test_dtype.py
     - name: Test ops with Python emulator

View File

@@ -30,7 +30,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
   dtype_out: DType # dtype for C and D
   threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
   thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
-  thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
+  thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
   wmma_func: str # name of wmma function to call
   def __str__(self): return f"tensor_core<{self.dims}, {self.dtype_in}, {self.dtype_out}>"
   def num_threads(self): return len(self.threads)
@@ -47,14 +47,17 @@ class TensorCoreOptions(NamedTuple):
 tensor_cores: Dict[str, List[TensorCore]] = {
   "METAL": [
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
   ],
   "HIP": [
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-  ]
+    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
+    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
+  ],
+  "CUDA": [
+    TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
+  ],
 }
 class LocalBuffer(NamedTuple):
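With thread_local_sizes now a List[List[int]], each operand's per-thread register count is the product of a list of factors rather than a single flat count. A minimal standalone sketch (not part of the commit) checking that the new CUDA entry works out to 8 A, 4 B, and 4 accumulator elements per thread, consistent with an m16n8k16 fragment split across a 32-thread warp:

  from math import prod

  # per-operand factor lists from the CUDA TensorCore entry above (A, B, C/D)
  cuda_thread_local_sizes = [[2,2,2],[2,2],[2,2]]
  wmma_sz = [prod(l) for l in cuda_thread_local_sizes]
  assert wmma_sz == [8, 4, 4]  # half8 A, half4 B, float4 accumulator

  # cross-check against the m16n8k16 tile: 32 threads jointly hold each fragment
  M, N, K = 16, 8, 16
  assert [M*K//32, K*N//32, M*N//32] == wmma_sz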

View File

@@ -172,7 +172,7 @@ class Linearizer(Kernel):
     # late alias the tensor core buffers
     if (tc:=self.tensor_core) and (tc_opts:=self.tensor_core_opts):
-      alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
+      alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1]*tc.num_upcasts() + [3]*(self.upcasted-tc.num_upcasts()) # noqa: E501
       for tc_buf in tc_opts.bufs:
         self.alias_buffer(tc_buf, alias_pattern)
@@ -263,27 +263,24 @@ class Linearizer(Kernel):
       out_buf = -1 if self.group_for_reduces else 0
       acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
-      # reduce loop
-      loop_ctx = render_loop(reduce_idxs)
       if (tc:=self.tensor_core):
-        def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
-          replace_idxs = []
+        def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
+          replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
+          for s in local_sizes:
+            thread_idxs.append(thread_idx % s)
+            thread_idx //= s
           for alias in aliases:
             full_var, full_var_sz = NumNode(0), 1
             if alias[0] != 0:
               for i in alias:
-                next_var = local_idxs[-i] if i > 0 else Variable("_uidx_tc", 0, local_size-1)
+                next_var = local_idxs[-i] if i > 0 else thread_idxs[-i-1]
                 full_var += next_var * full_var_sz
                 full_var_sz *= next_var.max+1
             replace_idxs.append(full_var)
           return replace_idxs
-        replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
-        for n in range(len(tc.threads)):
-          local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
-        for n in range(len(replace_acc_idxs)-len(tc.threads)):
-          upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
-        if DEBUG >= 3: print("store alias: idxs=", global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
+      # reduce loop
+      loop_ctx = render_loop(reduce_idxs)
       # compute local aliases
       locals_to_store = []
@@ -297,13 +294,20 @@ class Linearizer(Kernel):
             buf_idxs[self.first_reduce-tc.num_threads()+n] = replace_input_idxs[n] # replace locals
           for n in range(tc.num_upcasts()):
             buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[tc.num_threads()+n] # replace upcasts
-        if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
+        if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
         ll = self.global_load(i, buf_idxs)
         locals_to_store.append((localbuf_idx, buf_idxs, ll))
 
       # copy in any global buffers
       if (tc:=self.tensor_core):
-        wmma_sz = tc.thread_local_sizes
+        replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
+        for n in range(len(tc.threads)):
+          local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
+        for n in range(len(replace_acc_idxs)-len(tc.threads)):
+          upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
+        if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
+        wmma_sz = [prod(l) for l in tc.thread_local_sizes]
         def upcast_strides(buf:int):
           strides, next = [], 1
           for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
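The reworked calc_tc_idxs now decomposes a single upcast index into one mixed-radix digit per entry of local_sizes (in the kernel these are symbolic Variable nodes, not ints). A standalone sketch of just that split, using plain integers:

  def split_thread_idx(thread_idx: int, local_sizes: list) -> list:
    # least-significant digit first: digit k ranges over local_sizes[k]
    thread_idxs = []
    for s in local_sizes:
      thread_idxs.append(thread_idx % s)
      thread_idx //= s
    return thread_idxs

  # for the CUDA A operand, local_sizes=[2,2,2]: index 5 -> digits [1, 0, 1]
  assert split_thread_idx(5, [2,2,2]) == [1, 0, 1]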

View File

@@ -220,6 +220,7 @@ class UOpGraph:
         if u.uop is UOps.WMMA:
           if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
           elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
+          elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
           else: raise Exception("not implemented")
     return flops, mem
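Each WMMA uop is a fused multiply-add over an MxNxK tile, i.e. 2*M*N*K flops, issued cooperatively by a 32-thread warp, so each thread is charged 1/32 of the total. A quick standalone sketch of the accounting for the three backends:

  def wmma_flops(m: int, n: int, k: int, warp_size: int = 32) -> int:
    # one multiply + one add per cell of the MxNxK tile, split across the warp
    return 2 * (m * n * k) // warp_size

  assert wmma_flops(8, 8, 8) == 32       # __metal_wmma
  assert wmma_flops(16, 16, 16) == 256   # HIP wmma builtins
  assert wmma_flops(8, 16, 16) == 128    # __cuda_mma_m16n8k16_f16_f32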

View File

@@ -238,8 +238,14 @@ class CUDALanguage(CStyleLanguage):
   def render_kernel(self, function_name, kernel, bufs, local_size, uops, prefix=None):
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
     if any(uop.dtype == dtypes.half for uop in uops):
-      prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };",
-        "__device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; }"]
+      prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };", "struct half8 { half x, y, z, w, a, b, c, d; };",
+       "__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }",
+       "__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }",
+       """__device__ float4 __cuda_mma_m16n8k16_f16_f32(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
+  asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
+    : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
+  return c;}""",
+      ]
     if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <cuda_bf16.h>")
     return super().render_kernel(function_name, kernel, bufs, local_size, uops, prefix=prefix)
 CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
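The register counts in the inline PTX follow from the operand structs: halves are packed two to a 32-bit register, which is why the asm reinterprets the half8 and half4 arguments through int pointers. A small sketch of that arithmetic (the names here are illustrative, not from the commit):

  HALFS_PER_B32 = 2
  a_regs = 8 // HALFS_PER_B32  # half8 a  -> "r"(a_pk[0..3]), operands %4..%7
  b_regs = 4 // HALFS_PER_B32  # half4 b  -> "r"(b_pk[0..1]), operands %8..%9
  c_regs = 4                   # float4 c -> "+f"(c.x..c.w), read-write %0..%3
  assert (a_regs, b_regs, c_regs) == (4, 2, 4)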

View File

@@ -37,6 +37,7 @@ class CUDACompiler(Compiler):
   linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024])
   def __init__(self, arch:str):
     self.arch = arch
+    CUDACompiler.linearizer_opts = CUDACompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
     super().__init__(f"compile_cuda_{self.arch}")
   def render(self, name:str, uops) -> str: return CUDARenderer(name, uops)
   def compile(self, src:str) -> bytes:
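has_tensor_cores is gated on the numeric part of the arch string, so only sm_80 and newer take the new wmma path. A sketch of the parse, assuming arch strings of the form "sm_86":

  def has_tensor_cores(arch: str) -> bool:
    # "sm_86"[3:] == "86"; 80 and up (Ampere and newer) enables the wmma path
    return int(arch[3:]) >= 80

  assert not has_tensor_cores("sm_75")
  assert has_tensor_cores("sm_80") and has_tensor_cores("sm_89")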

View File

@@ -191,6 +191,11 @@ class PythonProgram:
               return a_elem(x, j, i, goff)
             def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
             ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
+          elif arg == '__cuda_mma_m16n8k16_f16_f32':
+            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4] # A (8 elements on 32 threads)
+            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4] # B (4 elements on 32 threads)
+            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
+            ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
           else:
             raise Exception(f"unimplemented tensor core {arg}")
         elif uop is UOps.ALU:
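A useful property of the emulator's new c_map: 32 lanes times 4 accumulator elements must cover every (i, j) cell of the 8 by 16 output tile exactly once. A standalone check (not in the commit) that the mapping is a bijection:

  def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)

  cells = {c_map(lane, elem) for lane in range(32) for elem in range(4)}
  assert cells == {(i, j) for i in range(8) for j in range(16)}  # 128 distinct cells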
@@ -203,7 +208,8 @@ class PythonProgram:
 class PythonCompiler(Compiler):
   linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
-    (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else LinearizerOptions("PYTHON"))
+    (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else \
+    (LinearizerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else LinearizerOptions("PYTHON")))
   def render(self, name:str, uops:List[UOp]) -> str:
     lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops]
     return base64.b64encode(pickle.dumps(lops)).decode()
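The nested conditional resolves in a fixed order, so EMULATE_METAL takes priority over EMULATE_HIP, which takes priority over the new EMULATE_CUDA. An equivalent sketch of the selection, flattened for readability (assumes the same getenv helper used above):

  from tinygrad.helpers import getenv

  def pick_emulated_backend() -> str:
    if getenv("EMULATE_METAL"): return "METAL"
    if getenv("EMULATE_HIP"): return "HIP"
    if getenv("EMULATE_CUDA"): return "CUDA"
    return "PYTHON"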