mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 23:54:58 -05:00
fix multigpu on tinybox (#2595)
* fix multigpu on tinybox * fixed multigpu
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Tuple, Optional, Union, List, cast
|
||||
import ctypes, functools
|
||||
import gpuctypes.opencl as cl
|
||||
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, diskcache, OSX, ImageDType
|
||||
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, diskcache, OSX, ImageDType, DEBUG
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import OpenCLRenderer
|
||||
from tinygrad.device import Compiled, LRUAllocator
|
||||
@@ -81,8 +81,12 @@ class CLDevice(Compiled):
|
||||
if CLDevice.device_ids is None:
|
||||
num_platforms = init_c_var(ctypes.c_uint32(), lambda x: check(cl.clGetPlatformIDs(0, None, ctypes.byref(x))))
|
||||
platform_ids = init_c_var((cl.cl_platform_id * num_platforms.value)(), lambda x: check(cl.clGetPlatformIDs(num_platforms.value, x, None)))
|
||||
num_devices = init_c_var(ctypes.c_uint32(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], cl.CL_DEVICE_TYPE_DEFAULT, 0, None, ctypes.byref(x))))
|
||||
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], cl.CL_DEVICE_TYPE_DEFAULT, num_devices, x, None)))
|
||||
for device_type in [cl.CL_DEVICE_TYPE_GPU, cl.CL_DEVICE_TYPE_DEFAULT]:
|
||||
num_devices = ctypes.c_uint32()
|
||||
err = cl.clGetDeviceIDs(platform_ids[0], device_type, 0, None, ctypes.byref(num_devices))
|
||||
if err == 0 and num_devices.value != 0: break
|
||||
if DEBUG >= 1: print(f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices")
|
||||
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None)))
|
||||
|
||||
self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
|
||||
self.context = checked(cl.clCreateContext(None, 1, ctypes.byref(self.device_id), cl.clCreateContext.argtypes[3](), None, ctypes.byref(status := ctypes.c_int32())), status)
|
||||
|
||||
Reference in New Issue
Block a user