wmma: add CUDA tensor core and fix test_speed_v_torch failure (#3544)

This commit is contained in:
Francis Lam
2024-03-01 17:51:02 -08:00
committed by GitHub
parent b3cdc11a58
commit e17f1821a7
8 changed files with 50 additions and 25 deletions

View File

@@ -37,6 +37,7 @@ class CUDACompiler(Compiler):
linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024])
def __init__(self, arch:str):
self.arch = arch
CUDACompiler.linearizer_opts = CUDACompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
super().__init__(f"compile_cuda_{self.arch}")
def render(self, name:str, uops) -> str: return CUDARenderer(name, uops)
def compile(self, src:str) -> bytes:

View File

@@ -191,6 +191,11 @@ class PythonProgram:
return a_elem(x, j, i, goff)
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif arg == '__cuda_mma_m16n8k16_f16_f32':
def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4] # A (8 elements on 32 threads)
def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4] # B (4 elements on 32 threads)
def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
else:
raise Exception(f"unimplemented tensor core {arg}")
elif uop is UOps.ALU:
@@ -203,7 +208,8 @@ class PythonProgram:
class PythonCompiler(Compiler):
linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
(LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else LinearizerOptions("PYTHON"))
(LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else \
(LinearizerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else LinearizerOptions("PYTHON")))
def render(self, name:str, uops:List[UOp]) -> str:
lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops]
return base64.b64encode(pickle.dumps(lops)).decode()