From c65ad43a9357bebe8bb21769336d792b44e5fdd0 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Thu, 17 Aug 2023 23:43:08 -0400
Subject: [PATCH] cleanup ops_gpu (#1566)

---
 tinygrad/runtime/ops_gpu.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index c800010b00..9182d6f735 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -30,17 +30,20 @@ class CLAllocator(LRUAllocator):
     return buf
 
 class _CL:
+  def __init__(self):
+    cl_platforms = cl.get_platforms()
+    platform_devices: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if len(y)]
+    self.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")]
+    self.cl_platform = self.devices[0].platform
   def post_init(self, device=None):
-    platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)]
-    self.cl_platform = cl.get_platforms()[getenv('CL_PLATFORM', 0)]
-    self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', "").split(",")] if device is None else [cl.Context(devices=[platforms[getenv('CL_PLATFORM', 0)][device]])]
+    self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in self.devices] if device is None else [cl.Context(devices=[self.devices[device]])]
     if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
     self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
     self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
   def synchronize(self):
     for q in self.cl_queue: q.finish()
 CL = _CL()
-CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None
+if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init()
 
 class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
   def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})