fix multigpu on tinybox (#2595)

* fix multigpu on tinybox

* fixed multigpu
This commit is contained in:
George Hotz
2023-12-03 16:48:07 -08:00
committed by GitHub
parent 61c0113928
commit fcd0b2ee6c
3 changed files with 19 additions and 10 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Tuple, Optional, Union, List, cast
import ctypes, functools
import gpuctypes.opencl as cl
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, diskcache, OSX, ImageDType
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, diskcache, OSX, ImageDType, DEBUG
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.device import Compiled, LRUAllocator
@@ -81,8 +81,12 @@ class CLDevice(Compiled):
if CLDevice.device_ids is None:
num_platforms = init_c_var(ctypes.c_uint32(), lambda x: check(cl.clGetPlatformIDs(0, None, ctypes.byref(x))))
platform_ids = init_c_var((cl.cl_platform_id * num_platforms.value)(), lambda x: check(cl.clGetPlatformIDs(num_platforms.value, x, None)))
num_devices = init_c_var(ctypes.c_uint32(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], cl.CL_DEVICE_TYPE_DEFAULT, 0, None, ctypes.byref(x))))
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], cl.CL_DEVICE_TYPE_DEFAULT, num_devices, x, None)))
for device_type in [cl.CL_DEVICE_TYPE_GPU, cl.CL_DEVICE_TYPE_DEFAULT]:
num_devices = ctypes.c_uint32()
err = cl.clGetDeviceIDs(platform_ids[0], device_type, 0, None, ctypes.byref(num_devices))
if err == 0 and num_devices.value != 0: break
if DEBUG >= 1: print(f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices")
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None)))
self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
self.context = checked(cl.clCreateContext(None, 1, ctypes.byref(self.device_id), cl.clCreateContext.argtypes[3](), None, ctypes.byref(status := ctypes.c_int32())), status)