Revert "wmma: add CUDA tensor core (#3464)" (#3474)

This reverts commit e9cef13f0b.
This commit is contained in:
George Hotz
2024-02-22 11:58:16 +01:00
committed by GitHub
parent e9cef13f0b
commit 7698781389
7 changed files with 27 additions and 49 deletions


@@ -45,7 +45,6 @@ jobs:
       run: |
         DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
         DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
-        DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
     - name: Test dtype with Python emulator
       run: DEBUG=2 PYTHON=1 python3 test/test_dtype.py
     - name: Test ops with Python emulator


@@ -30,7 +30,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
   dtype_out: DType # dtype for C and D
   threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
   thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
-  thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
+  thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
   wmma_func: str # name of wmma function to call
   def __str__(self): return f"tensor_core<{self.dims}, {self.dtype_in}, {self.dtype_out}>"
   def num_threads(self): return len(self.threads)
@@ -38,17 +38,14 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
 tensor_cores: Dict[str, List[TensorCore]] = {
   "METAL": [
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
   ],
   "HIP": [
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-  ],
-  "CUDA": [
-    TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,4),(1,4),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[0],[-1,1,-3],[2,3,-2]], [[0],[0],[2,3],[-1,1,-2],[0]], [[0],[0],[-1,1],[0],[2,3,-2]] ]), # noqa: E501
-  ],
+    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
+    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
+  ]
 }
 class LocalBuffer(NamedTuple):
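An editorial aside on the field change above: #3464 widened thread_local_sizes to nested lists because the CUDA fragments hold different element counts per operand, and the wmma vector widths were then the products of each sub-list (see the deleted `wmma_sz = [prod(l) for l in tc.thread_local_sizes]` line further down); the revert returns to flat counts. A minimal plain-Python sketch, values copied from the entries above (variable names are ours):

from math import prod

cuda_sizes  = [[2,2,2], [2,2], [2,2]]  # deleted "CUDA" entry: per-thread elements for A, B, C/D
metal_sizes = [2, 2, 2]                # restored METAL entries: one flat count per fragment

# the reverted linearizer line collapsed the nesting into wmma vector widths:
wmma_sz = [prod(l) for l in cuda_sizes]
assert wmma_sz == [8, 4, 4]            # matches the half8 A, half4 B, float4 C of the CUDA wmma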


@@ -252,27 +252,30 @@ class Linearizer(Kernel):
     # define accumulator
     acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
-    # reduce loop
-    loop_ctx = render_loop(reduce_idxs)
-    # barrier for fast GEMM
     if (tc:=self.tensor_core):
-      self.uop(UOps.BARRIER, None, (), cachable=False)
-      def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
-        replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
-        for s in local_sizes:
-          thread_idxs.append(thread_idx % s)
-          thread_idx //= s
+      def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
+        replace_idxs = []
         for alias in aliases:
           full_var, full_var_sz = NumNode(0), 1
           if alias[0] != 0:
             for i in alias:
-              next_var = local_idxs[-i] if i > 0 else thread_idxs[-i-1]
+              next_var = local_idxs[-i] if i > 0 else Variable("_uidx_tc", 0, local_size-1)
               full_var += next_var * full_var_sz
               full_var_sz *= next_var.max+1
           replace_idxs.append(full_var)
         return replace_idxs
+      replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
+      for n in range(len(tc.threads)):
+        local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
+      for n in range(len(replace_acc_idxs)-len(tc.threads)):
+        upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
+      if DEBUG >= 3: print("store alias: idxs=", global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
+    # reduce loop
+    loop_ctx = render_loop(reduce_idxs)
+    # barrier for fast GEMM
+    if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
     # compute local aliases
     locals_to_store = []
@@ -286,20 +289,13 @@ class Linearizer(Kernel):
         buf_idxs[self.first_reduce-tc.num_threads()+n] = replace_input_idxs[n] # replace locals
       for n in range(tc.num_upcasts()):
         buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[tc.num_threads()+n] # replace upcasts
-      if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs=", buf_idxs)
+      if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
       ll = self.global_load(i, buf_idxs)
       locals_to_store.append((localbuf_idx, buf_idxs, ll))
     # copy in any global buffers
     if (tc:=self.tensor_core):
-      replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
-      for n in range(len(tc.threads)):
-        local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
-      for n in range(len(replace_acc_idxs)-len(tc.threads)):
-        upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
-      if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
-      wmma_sz = [prod(l) for l in tc.thread_local_sizes]
+      wmma_sz = tc.thread_local_sizes
       def upcast_strides(buf:int):
         strides, next = [], 1
         for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
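What actually changed in calc_tc_idxs: the reverted #3464 version split a single flat "_uidx_tc" index into one mixed-radix digit per entry of local_sizes (negative alias entries then pick thread_idxs[-i-1]), whereas the restored version hands each alias a single flat Variable. A standalone sketch with plain ints standing in for tinygrad's symbolic Variables; split_flat_idx is our illustrative name:

from typing import List

def split_flat_idx(thread_idx: int, local_sizes: List[int]) -> List[int]:
  # least-significant digit first, exactly like the deleted `% s` / `//= s` loop
  thread_idxs = []
  for s in local_sizes:
    thread_idxs.append(thread_idx % s)
    thread_idx //= s
  return thread_idxs

# e.g. the deleted CUDA A-fragment sizes [2,2,2]: flat index 5 -> digits [1, 0, 1]
assert split_flat_idx(5, [2, 2, 2]) == [1, 0, 1]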


@@ -115,6 +115,5 @@ def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint
     if u.uop is UOps.WMMA:
       if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
       elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
-      elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
       else: raise Exception("not implemented")
   return flops, mem
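For the deleted flops branch: a WMMA op is a full M×N×K matrix multiply-accumulate (2 ops per MAC) executed cooperatively by a 32-lane warp, hence 2*M*N*K//32 flops per thread per call. Checking all three cases from the function above in plain Python:

# dims copied from the tensor_cores table; flops per thread per WMMA call
for name, (m, n, k) in [
  ("__metal_wmma (8x8x8)",                       (8, 8, 8)),
  ("__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", (16, 16, 16)),
  ("__cuda_mma_m16n8k16_f16_f32 (removed)",      (8, 16, 16)),
]:
  print(name, 2 * (m * n * k) // 32)   # -> 32, 256, 128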


@@ -236,14 +236,8 @@ class CUDALanguage(CStyleLanguage):
   def render_kernel(self, function_name, kernel, bufs, local_size, uops, prefix=None):
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
     if any(uop.dtype == dtypes.half for uop in uops):
-      prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };", "struct half8 { half x, y, z, w, a, b, c, d; };",
-        "__device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; }",
-        "__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; }",
-        """__device__ float4 __cuda_mma_m16n8k16_f16_f32(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
-  asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
-    : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
-  return c;}""",
-      ]
+      prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };",
+        "__device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; }"]
     if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <cuda_bf16.h>")
     return super().render_kernel(function_name, kernel, bufs, local_size, uops, prefix=prefix)
 CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
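A note on the removed inline PTX: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 consumes its f16 fragments packed two per 32-bit register, which is why the deleted wrapper reinterprets a and b through int*. The operand counts line up with the asm's ten %-slots; a back-of-the-envelope check in plain Python (names ours):

HALF_BYTES, FLOAT_BYTES, REG_BYTES = 2, 4, 4
a_regs = 8 * HALF_BYTES  // REG_BYTES   # half8  -> 4 x "r" inputs   (%4-%7)
b_regs = 4 * HALF_BYTES  // REG_BYTES   # half4  -> 2 x "r" inputs   (%8-%9)
c_regs = 4 * FLOAT_BYTES // REG_BYTES   # float4 -> 4 x "+f" in/outs (%0-%3)
assert (a_regs, b_regs, c_regs) == (4, 2, 4) and a_regs + b_regs + c_regs == 10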


@@ -33,7 +33,6 @@ class CUDACompiler(Compiler):
linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024])
def __init__(self, arch:str):
self.arch = arch
CUDACompiler.linearizer_opts = CUDACompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
super().__init__(f"compile_cuda_{self.arch}")
def render(self, name:str, uops) -> str: return CUDARenderer(name, uops)
def compile(self, src:str) -> bytes:
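The deleted _replace line is the capability gate: assuming arch strings shaped like "sm_86" (so arch[3:] is the numeric compute capability), the mma.sync path was only switched on for sm_80 (Ampere) and newer. A sketch of that check:

def has_tensor_cores(arch: str) -> bool:
  # assumes the "sm_NN" form, e.g. "sm_86" -> 86
  return int(arch[3:]) >= 80

assert [has_tensor_cores(a) for a in ("sm_75", "sm_80", "sm_86")] == [False, True, True]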


@@ -189,11 +189,6 @@ class PythonProgram:
             return a_elem(x, j, i, goff)
           def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
           ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-        elif arg == '__cuda_mma_m16n8k16_f16_f32':
-          def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4] # A (8 elements on 32 threads)
-          def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4] # B (4 elements on 32 threads)
-          def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
-          ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
         else:
           raise Exception(f"unimplemented tensor core {arg}")
       elif uop is UOps.ALU:
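The removed emulator mapping is easy to sanity-check in isolation: across 32 lanes with 4 accumulator elements each, c_map should cover the 8x16 (i, j) index space of the C/D tile exactly once. A standalone check using the deleted line verbatim:

def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)

coords = {c_map(lane, elem) for lane in range(32) for elem in range(4)}
assert len(coords) == 32 * 4 == 128                             # no two (lane, elem) collide
assert coords == {(i, j) for i in range(8) for j in range(16)}  # exact tile coverage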
@@ -206,8 +201,7 @@ class PythonProgram:
 class PythonCompiler(Compiler):
   linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
-    (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else \
-    (LinearizerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else LinearizerOptions()))
+    (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else LinearizerOptions())
   def render(self, name:str, uops:List[UOp]) -> str:
     lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops]
     return base64.b64encode(pickle.dumps(lops)).decode()