mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-26 15:28:10 -05:00
compile cache for several devices (#3148)
* compile cache for several devices
* ops_gpu uses hash to not care about sql
* hip rdna test with device
* linter happy
* no device passed where possible
* arch is optional to compile_{hip|cuda}
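
For context: tinygrad caches compiled kernel binaries on disk under a per-compiler name, and before this change every device of a backend shared one name (e.g. "compile_cuda"), so a binary built for one GPU architecture could be handed back for a different one. The sketch below is a simplified in-memory stand-in for that cache (not tinygrad's actual diskcache helpers, which are not part of this diff) showing why folding the arch into the cache name separates the entries:

import hashlib

# Hypothetical in-memory stand-in for the on-disk compile cache.
_cache: dict[tuple[str, str], bytes] = {}

def cached_compile(cache_name: str, src: str, compiler) -> bytes:
  # Entries are keyed by the cache name plus a hash of the source.
  key = (cache_name, hashlib.md5(src.encode()).hexdigest())
  if key not in _cache: _cache[key] = compiler(src)
  return _cache[key]

# With one shared name, a second device would be served the first device's
# binary. With per-arch names like "compile_cuda_sm_70" and
# "compile_cuda_sm_86", the same source compiles (and caches) once per arch.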
tinygrad/runtime/ops_cuda.py

@@ -29,7 +29,7 @@ def check(status):
 def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable)  # noqa: E501

-def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check)  # noqa: E501
+def compile_cuda(prg:str, arch="sm_35") -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check)  # noqa: E501

 class CUDAProgram:
   def __init__(self, device:CUDADevice, name:str, lib:bytes):
@@ -39,7 +39,7 @@ class CUDAProgram:
     try:
       fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
       with open(fn + ".ptx", "wb") as f: f.write(lib)
-      subprocess.run(["ptxas", f"-arch={CUDADevice.default_arch_name}", "-o", fn, fn+".ptx"], check=True)
+      subprocess.run(["ptxas", f"-arch={device.arch}", "-o", fn, fn+".ptx"], check=True)
       print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
     except Exception as e: print("failed to generate SASS", str(e))
@@ -73,7 +73,6 @@ class CUDAAllocator(LRUAllocator):
     check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))

 class CUDADevice(Compiled):
-  default_arch_name = "sm_35"
   def __init__(self, device:str):
     device_id = int(device.split(":")[1]) if ":" in device else 0
     if not CUDACPU:
@@ -81,12 +80,13 @@ class CUDADevice(Compiled):
       check(cuda.cuDeviceGet(ctypes.byref(device := cuda.CUdevice()), device_id))
       self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, device)))
       check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id))
-      if device_id == 0: CUDADevice.default_arch_name = f"sm_{major.value}{minor.value}"
+    self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"

     from tinygrad.runtime.graph.cuda import CUDAGraph
     super().__init__(CUDAAllocator(self) if not CUDACPU else MallocAllocator,
                      LinearizerOptions("CUDA", supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
-                     CUDARenderer, compile_cuda, "compile_cuda", functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
+                     CUDARenderer, functools.partial(compile_cuda,arch=self.arch), f"compile_cuda_{self.arch}", functools.partial(CUDAProgram, self),
+                     graph=CUDAGraph if not CUDACPU else None)
   def synchronize(self):
     if not CUDACPU:
       check(cuda.cuCtxSetCurrent(self.context))
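
Note on the CUDA change: the arch string is now derived per device from cuDeviceComputeCapability and bound into the compiler with functools.partial, so the rest of the pipeline still calls the compiler with just the source string. A minimal illustration (the compute-capability values are only examples, and compile_cuda here is a stand-in for the real NVRTC call):

import functools

def compile_cuda(prg: str, arch="sm_35") -> bytes:
  # stand-in for the real NVRTC invocation
  return f"compiled for {arch}: {prg}".encode()

major, minor = 8, 6  # e.g. from cuDeviceComputeCapability
compiler = functools.partial(compile_cuda, arch=f"sm_{major}{minor}")
print(compiler("__global__ void k() {}"))  # arch is already bound to sm_86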
tinygrad/runtime/ops_gpu.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
 from typing import Tuple, Optional, List
-import ctypes, functools
+import ctypes, functools, hashlib
 import gpuctypes.opencl as cl
 from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
 from tinygrad.dtype import ImageDType
@@ -87,12 +87,15 @@ class CLDevice(Compiled):
       CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None)))  # noqa: E501

     self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
+    self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256, ctypes.byref(buf := ctypes.create_string_buffer(256)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1]  # noqa: E501
+    self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256, ctypes.byref(buf := ctypes.create_string_buffer(256)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1]  # noqa: E501
     self.context = checked(cl.clCreateContext(None, 1, ctypes.byref(self.device_id), cl.clCreateContext.argtypes[3](), None, ctypes.byref(status := ctypes.c_int32())), status)  # noqa: E501
     self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, ctypes.byref(status)), status)
     self.pending_copyin: List[memoryview] = []
-    # TODO: vary the cache key based on device name

+    compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
     super().__init__(CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer,
-                     functools.partial(compile_cl, self), "compile_cl", functools.partial(CLProgram, self))
+                     functools.partial(compile_cl, self), f"compile_cl_{compile_key}", functools.partial(CLProgram, self))
   def synchronize(self):
     check(cl.clFinish(self.queue))
     self.pending_copyin.clear()
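
Note on the OpenCL change: OpenCL exposes no single arch string, so the cache key is a hash of the device name together with the driver version. Per the commit message ("ops_gpu uses hash to not care about sql"), a hex digest also embeds safely in the SQLite-backed cache's naming, which raw device names (spaces, punctuation) would not. A standalone sketch of the key construction, with illustrative device strings:

import hashlib

device_name = "NVIDIA GeForce RTX 3080"  # as returned by clGetDeviceInfo(CL_DEVICE_NAME)
driver_version = "535.104.05"            # as returned by clGetDeviceInfo(CL_DRIVER_VERSION)
compile_key = hashlib.md5(device_name.encode() + driver_version.encode()).hexdigest()
print(f"compile_cl_{compile_key}")       # hex-only key, distinct per device + driver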
tinygrad/runtime/ops_hip.py

@@ -17,7 +17,7 @@ def check(status):
 # TODO: remove these helpers, they increase complexity
 def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable)  # noqa: E501

-def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}', '-I/opt/rocm/include'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check)  # noqa: E501
+def compile_hip(prg:str, arch="gfx1100") -> bytes: return compile_cuda_style(prg, [f'--offload-arch={arch}', '-I/opt/rocm/include'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check)  # noqa: E501

 class HIPProgram:
   def __init__(self, device:int, name:str, lib:bytes):
@@ -81,15 +81,14 @@ class HIPAllocator(LRUAllocator):
     check(hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice))

 class HIPDevice(Compiled):
-  default_arch_name = "gfx1100"
   def __init__(self, device:str=""):
     self.device = int(device.split(":")[1]) if ":" in device else 0
+    self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() if not MOCKHIP else "gfx1100"  # noqa: E501
     self.pending_copyin: List[hip.hipDeviceptr_t] = []
-    if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode()  # noqa: E501

     from tinygrad.runtime.graph.hip import HIPGraph
     super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
-                     compile_hip, "compile_hip", functools.partial(HIPProgram, self.device), HIPGraph)
+                     functools.partial(compile_hip,arch=self.arch), f"compile_hip_{self.arch}", functools.partial(HIPProgram, self.device), HIPGraph)
   def synchronize(self):
     check(hip.hipSetDevice(self.device))
     check(hip.hipDeviceSynchronize())
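
Note on the HIP change: the per-device architecture now comes straight from hipGetDeviceProperties, whose gcnArchName field carries a string such as "gfx1100" (this is what the hip rdna test exercises). A minimal sketch using the same gpuctypes bindings the file imports, assuming a working ROCm install and device 0:

import ctypes
import gpuctypes.hip as hip

props = hip.hipDeviceProp_t()
if hip.hipGetDeviceProperties(ctypes.byref(props), 0) == 0:  # 0 == hipSuccess
  print(props.gcnArchName.decode())  # e.g. "gfx1100" on an RDNA3 GPU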