From b0dd407cdd1d9e93e12874791788bd7014dc342c Mon Sep 17 00:00:00 2001
From: Francis Lam
Date: Fri, 11 Oct 2024 11:51:06 -0700
Subject: [PATCH] ops_cuda: add optional dynamic smem parameter (#6956)

* ops_cuda: add optional dynamic smem parameter

This is required to enable shared memory usage larger than 48 KiB on a
per-kernel basis.

* move setting max dynamic smem size to init
---
 tinygrad/runtime/ops_cuda.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 7a004e20ca..33288073f1 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -31,8 +31,8 @@ def cu_time_execution(cb, enable=False) -> Optional[float]:
   return ret.value * 1e-3
 
 class CUDAProgram:
-  def __init__(self, device:CUDADevice, name:str, lib:bytes):
-    self.device, self.name, self.lib = device, name, lib
+  def __init__(self, device:CUDADevice, name:str, lib:bytes, smem:int=0):
+    self.device, self.name, self.lib, self.smem = device, name, lib, smem
     if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
     if DEBUG >= 6: cuda_disassemble(lib, device.arch)
 
@@ -45,6 +45,7 @@ class CUDAProgram:
         raise RuntimeError(f"module load failed with status code {status}: {cuda.cudaError_enum__enumvalues[status]}")
       check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
     self.prg = prg #type: ignore
+    if self.smem > 0: check(cuda.cuFuncSetAttribute(self.prg, cuda.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, self.smem))
 
   def __del__(self):
     if hasattr(self, 'module'): check(cuda.cuModuleUnload(self.module))
@@ -56,7 +57,7 @@ class CUDAProgram:
     else:
       for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
       for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
-    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs)), enable=wait)
+    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, self.smem, None, None, self.vargs)), enable=wait)
 
 class CUDAAllocator(LRUAllocator):
   def __init__(self, device:CUDADevice):
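
Usage note (not part of the patch): a minimal sketch of how a caller might opt a
single kernel into more than 48 KiB of dynamic shared memory via the new parameter.
The device string, kernel name, cubin path, and launch dimensions below are
illustrative assumptions; only the CUDAProgram(..., smem=...) parameter and its
forwarding to cuFuncSetAttribute and cuLaunchKernel come from the patch itself.

from tinygrad.runtime.ops_cuda import CUDADevice, CUDAProgram

dev = CUDADevice("CUDA")                     # hypothetical: first visible CUDA device
lib = open("my_kernel.cubin", "rb").read()   # hypothetical precompiled kernel image

# request 64 KiB of dynamic shared memory; without the cuFuncSetAttribute call added
# in __init__, asking for more than 48 KiB at launch time fails on current GPUs
prg = CUDAProgram(dev, "my_kernel", lib, smem=64 * 1024)

# the same smem value is forwarded as the shared memory size argument of cuLaunchKernel
# (buffer arguments omitted for brevity)
prg(global_size=(1024, 1, 1), local_size=(256, 1, 1), wait=True)

Setting the attribute once in __init__ (rather than before every launch) matches the
second commit in this patch: the opt-in is a per-function property, so it only needs
to be applied when the CUfunction is loaded.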