mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
cleanup ops_gpu (#1566)
This commit is contained in:
@@ -30,17 +30,20 @@ class CLAllocator(LRUAllocator):
|
||||
return buf
|
||||
|
||||
class _CL:
|
||||
def __init__(self):
|
||||
cl_platforms = cl.get_platforms()
|
||||
platform_devices: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if len(y)]
|
||||
self.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")]
|
||||
self.cl_platform = self.devices[0].platform
|
||||
def post_init(self, device=None):
|
||||
platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)]
|
||||
self.cl_platform = cl.get_platforms()[getenv('CL_PLATFORM', 0)]
|
||||
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', "").split(",")] if device is None else [cl.Context(devices=[platforms[getenv('CL_PLATFORM', 0)][device]])]
|
||||
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in self.devices] if device is None else [cl.Context(devices=[self.devices[device]])]
|
||||
if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
|
||||
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
|
||||
self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
|
||||
def synchronize(self):
|
||||
for q in self.cl_queue: q.finish()
|
||||
CL = _CL()
|
||||
CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None
|
||||
if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init()
|
||||
|
||||
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
|
||||
|
||||
Reference in New Issue
Block a user