diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 1259552c4b..ea463478b7 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -12,10 +12,10 @@ class CL: CACHE = None def __init__(self): if getattr(CL, "cl_queue", None) is not None: return - devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU) + devices = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], []) if len(devices) == 0: # settle for CPU - devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.CPU) - CL.cl_ctx = cl.Context(devices=devices) + devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], []) + CL.cl_ctx = cl.Context(devices=[devices[int(os.getenv("CL_DEVICE", "0"))]]) CL.cl_queue = cl.CommandQueue(self.cl_ctx) # this is an in-order command queue @staticmethod