auto arg dtypes (#1623)

This commit is contained in:
George Hotz
2023-08-22 10:22:40 -07:00
committed by GitHub
parent db8344ab83
commit 463dece63e

View File

@@ -62,7 +62,7 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
self.name, self.argdtypes, self.clprograms = name, argdtypes, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
try:
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
except cl.RuntimeError as e:
@@ -78,15 +78,17 @@ class CLProgram:
else:
# print the PTX for NVIDIA. TODO: probably broken for everything else
print(self.binary().decode('utf-8'))
if self.argdtypes is not None: _ = [clprg.set_scalar_arg_dtypes(self.argdtypes) for clprg in self.clprgs]
if argdtypes is not None: self.set_argdtypes(argdtypes)
def binary(self):
    """Return the first BINARIES entry of the first program in ``self.clprograms``."""
    # Only the first program is queried; entries for any other contexts are ignored.
    binaries = self.clprograms[0].get_info(cl.program_info.BINARIES)
    return binaries[0]
def set_argdtypes(self, argdtypes):
    """Apply ``argdtypes`` to every kernel in ``self.clprgs`` and remember them.

    Each kernel receives the same ``argdtypes`` object through its
    ``set_scalar_arg_dtypes`` method.
    """
    # Configure the kernels first so that ``self.argdtypes`` is only recorded
    # once every ``set_scalar_arg_dtypes`` call has succeeded.
    for kernel in self.clprgs:
        kernel.set_scalar_arg_dtypes(argdtypes)
    self.argdtypes = argdtypes
@staticmethod
def max_work_group_size():
    """Return ``max_work_group_size`` of the first device in the first CL context."""
    # Only CL.cl_ctxs[0].devices[0] is consulted; other contexts/devices are not checked.
    first_device = CL.cl_ctxs[0].devices[0]
    return first_device.max_work_group_size
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs]
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if isinstance(x, CLBuffer) else np.int32 for x in bufs))
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
if wait:
e.wait()