mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 09:05:40 -05:00
wmma: add CUDA tensor core and fix test_speed_v_torch failure (#3544)
This commit is contained in:
@@ -37,6 +37,7 @@ class CUDACompiler(Compiler):
|
||||
linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024])
|
||||
def __init__(self, arch:str):
|
||||
self.arch = arch
|
||||
CUDACompiler.linearizer_opts = CUDACompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
|
||||
super().__init__(f"compile_cuda_{self.arch}")
|
||||
def render(self, name:str, uops) -> str: return CUDARenderer(name, uops)
|
||||
def compile(self, src:str) -> bytes:
|
||||
|
||||
@@ -191,6 +191,11 @@ class PythonProgram:
|
||||
return a_elem(x, j, i, goff)
|
||||
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
|
||||
ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
||||
elif arg == '__cuda_mma_m16n8k16_f16_f32':
|
||||
def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4] # A (8 elements on 32 threads)
|
||||
def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4] # B (4 elements on 32 threads)
|
||||
def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
|
||||
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
|
||||
else:
|
||||
raise Exception(f"unimplemented tensor core {arg}")
|
||||
elif uop is UOps.ALU:
|
||||
@@ -203,7 +208,8 @@ class PythonProgram:
|
||||
|
||||
class PythonCompiler(Compiler):
|
||||
linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
|
||||
(LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else LinearizerOptions("PYTHON"))
|
||||
(LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else \
|
||||
(LinearizerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else LinearizerOptions("PYTHON")))
|
||||
def render(self, name:str, uops:List[UOp]) -> str:
|
||||
lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops]
|
||||
return base64.b64encode(pickle.dumps(lops)).decode()
|
||||
|
||||
Reference in New Issue
Block a user