mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
auto arg dtypes (#1623)
This commit is contained in:
@@ -62,7 +62,7 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
|
||||
class CLProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
|
||||
self.name, self.argdtypes, self.clprograms = name, argdtypes, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
|
||||
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
|
||||
try:
|
||||
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
|
||||
except cl.RuntimeError as e:
|
||||
@@ -78,15 +78,17 @@ class CLProgram:
|
||||
else:
|
||||
# print the PTX for NVIDIA. TODO: probably broken for everything else
|
||||
print(self.binary().decode('utf-8'))
|
||||
if self.argdtypes is not None: _ = [clprg.set_scalar_arg_dtypes(self.argdtypes) for clprg in self.clprgs]
|
||||
if argdtypes is not None: self.set_argdtypes(argdtypes)
|
||||
|
||||
def binary(self): return self.clprograms[0].get_info(cl.program_info.BINARIES)[0]
|
||||
def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_scalar_arg_dtypes(argdtypes) for clprg in self.clprgs]
|
||||
|
||||
@staticmethod
|
||||
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
|
||||
cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs]
|
||||
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if isinstance(x, CLBuffer) else np.int32 for x in bufs))
|
||||
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
|
||||
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
|
||||
if wait:
|
||||
e.wait()
|
||||
|
||||
Reference in New Issue
Block a user