mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-13 08:05:10 -05:00
compile cache for several devices (#3148)
* compile cache for several devices
* ops_gpu uses hash to not care about sql
* hip rdna test with device
* linter happy
* no device passed where possible
* arch is optional to compile_{hip|cuda}
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Tuple, Optional, List
|
||||
import ctypes, functools
|
||||
import ctypes, functools, hashlib
|
||||
import gpuctypes.opencl as cl
|
||||
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
|
||||
from tinygrad.dtype import ImageDType
|
||||
@@ -87,12 +87,15 @@ class CLDevice(Compiled):
|
||||
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None))) # noqa: E501
|
||||
|
||||
self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
|
||||
self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256, ctypes.byref(buf := ctypes.create_string_buffer(256)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1] # noqa: E501
|
||||
self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256, ctypes.byref(buf := ctypes.create_string_buffer(256)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1] # noqa: E501
|
||||
self.context = checked(cl.clCreateContext(None, 1, ctypes.byref(self.device_id), cl.clCreateContext.argtypes[3](), None, ctypes.byref(status := ctypes.c_int32())), status) # noqa: E501
|
||||
self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, ctypes.byref(status)), status)
|
||||
self.pending_copyin: List[memoryview] = []
|
||||
# TODO: vary the cache key based on device name
|
||||
|
||||
compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
|
||||
super().__init__(CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer,
|
||||
functools.partial(compile_cl, self), "compile_cl", functools.partial(CLProgram, self))
|
||||
functools.partial(compile_cl, self), f"compile_cl_{compile_key}", functools.partial(CLProgram, self))
|
||||
def synchronize(self):
|
||||
check(cl.clFinish(self.queue))
|
||||
self.pending_copyin.clear()
|
||||
|
||||
Reference in New Issue
Block a user