cl fixes for multigpu (#1276)

* feat: opencl fixes for multigpu usage

* clean: who needs this import anyways
This commit is contained in:
wozeparrot
2023-07-18 22:59:30 -04:00
committed by GitHub
parent fa0265b173
commit 37cc33269a
2 changed files with 33 additions and 30 deletions

View File

@@ -74,29 +74,29 @@ class Thneed:
if o['arg_type'] == "image2d_t":
if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
# hack: use a image1d since we can back that with a buffer
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
else:
# buffer isn't supported in image2d, copy buffer into image
if 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
cl.enqueue_copy(CL.cl_queue[0], arr, bufs[o['buffer_id']])
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
elif o['needs_load']:
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
else:
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
if o['arg_type'] == "image1d_t":
assert not o['needs_load']
assert not bufs_loaded[o['buffer_id']]
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
else:
if 'data' in o:
buf = cl.Buffer(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
else:
# zero out buffers
buf = cl.Buffer(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
bufs[o['id']] = buf
bufs_loaded[o['id']] = 'data' in o
@@ -161,7 +161,7 @@ class Thneed:
for prg, args in self.cl_cache:
# get binaries for saving
if prg.name not in saved_binaries:
binary = prg.clprogram.get_info(cl.program_info.BINARIES)
binary = prg.clprograms[0].get_info(cl.program_info.BINARIES)
assert len(binary) == 1
jdat['binaries'].append({"name":prg.name, "length":len(binary[0])})
binaries.append(binary[0])
@@ -201,7 +201,7 @@ class Thneed:
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
size = row_pitch * a.shape[1]
# this is *2 if float16 and *4 if float32
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
buf = cl.Buffer(CL.cl_ctxs[0], cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
# zero out the buffer
cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True)
@@ -271,7 +271,7 @@ class Thneed:
events = []
st = time.monotonic()
for prg, args in self.cl_cache:
events.append(prg.clprg(CL.cl_queue[0], *args))
events.append(prg.clprgs[0](CL.cl_queue[0], *args))
mt = time.monotonic()
CL.synchronize()
et = time.monotonic() - st

View File

@@ -17,65 +17,68 @@ if DEBUG >= 5:
early_exec = fromimport("extra.helpers", "enable_early_exec")()
class _CL:
def __init__(self):
def __init__(self): self.events_in_flight = []
def post_init(self, device=None):
platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)]
devices: List[cl.Device] = [x for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', '').split(',')]
if DEBUG >= 1: print(f"using devices: {[d.hashable_model_and_version_identifier for d in devices]}")
self.cl_ctx: cl.Context = cl.Context(devices=devices)
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(self.cl_ctx, device=device, properties=cl.command_queue_properties.PROFILING_ENABLE) for device in self.cl_ctx.devices]
self.cl_platform = cl.get_platforms()[getenv('CL_PLATFORM', 0)]
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', "").split(",")] if device is None else [cl.Context(devices=[platforms[getenv('CL_PLATFORM', 0)][device]])]
if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
def synchronize(self):
for evt in self.events_in_flight: evt.wait()
self.events_in_flight.clear()
for q in self.cl_queue: q.finish()
CL = _CL()
CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None
class CLBuffer(RawBufferCopyInOut):
def __init__(self, size, dtype, device='0'):
if isinstance(dtype, ImageDType):
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
else:
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size * dtype.itemsize)
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
super().__init__(size, dtype, buf)
def _copyin(self, x: np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
def _copyout(self, x: np.ndarray):
CL.events_in_flight.append(cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False))
def _copyout(self, x:np.ndarray):
CL.synchronize()
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
cl.enqueue_copy(CL.cl_queue[self._buf.device], x, self._buf, is_blocking=True)
class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
self.name, self.argdtypes, self.clprogram = name, argdtypes, cl.Program(CL.cl_ctx, CL.cl_ctx.devices, [prg]*len(CL.cl_ctx.devices)) if binary else cl.Program(CL.cl_ctx, prg) # type: ignore
self.name, self.argdtypes, self.clprograms = name, argdtypes, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
try:
self._clprg = self.clprogram.build(options=options)
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
except cl.RuntimeError as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
self.clprg = self._clprg.__getattr__(name)
self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs]
if DEBUG >= 5 and not OSX:
if 'Adreno' in CL.cl_ctx.devices[0].name:
if 'Adreno' in CL.cl_ctxs[0].devices[0].name:
fromimport('disassemblers.adreno', 'disasm')(self.binary())
elif CL.cl_ctx.devices[0].name.startswith('gfx'):
elif CL.cl_ctxs[0].devices[0].name.startswith('gfx'):
asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], self.binary()))
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
else:
# print the PTX for NVIDIA. TODO: probably broken for everything else
print(self.binary().decode('utf-8'))
if self.argdtypes is not None: self.clprg.set_scalar_arg_dtypes(self.argdtypes)
if self.argdtypes is not None: _ = [clprg.set_scalar_arg_dtypes(self.argdtypes) for clprg in self.clprgs]
def binary(self): return self.clprogram.get_info(cl.program_info.BINARIES)[0]
def binary(self): return self.clprograms[0].get_info(cl.program_info.BINARIES)[0]
@staticmethod
def max_work_group_size(): return CL.cl_ctx.devices[0].max_work_group_size
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
e = self.clprg(CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs)
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs)
if wait:
e.wait()
try: