From 463dece63e2ab7a996ee480401c3a1e5f699ecf3 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 22 Aug 2023 10:22:40 -0700 Subject: [PATCH] auto arg dtypes (#1623) --- tinygrad/runtime/ops_gpu.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 6cfa532f68..0aba0e7958 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -62,7 +62,7 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer): class CLProgram: def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None): - self.name, self.argdtypes, self.clprograms = name, argdtypes, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore + self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore try: self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms] except cl.RuntimeError as e: @@ -78,15 +78,17 @@ class CLProgram: else: # print the PTX for NVIDIA. TODO: probably broken for everything else print(self.binary().decode('utf-8')) - if self.argdtypes is not None: _ = [clprg.set_scalar_arg_dtypes(self.argdtypes) for clprg in self.clprgs] + if argdtypes is not None: self.set_argdtypes(argdtypes) def binary(self): return self.clprograms[0].get_info(cl.program_info.BINARIES)[0] + def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_scalar_arg_dtypes(argdtypes) for clprg in self.clprgs] @staticmethod def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]: - cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs] + if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if isinstance(x, CLBuffer) else np.int32 for x in bufs)) + cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs] e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")]) if wait: e.wait()