From 4356683081738832d32ca5fe032546b099cc12e8 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 9 Jan 2023 19:32:22 -0800 Subject: [PATCH] gpu: rename kernels --- tinygrad/llops/ops_gpu.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 40be62cd14..03ef3157ac 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -43,21 +43,22 @@ class CL: @staticmethod def enqueue_copy(a, b, is_blocking=False): if CL.CACHE is not None: - assert False, "can't copy while caching" + assert False, f"can't copy {a} -> {b} while caching" if DEBUG >= 1: print(f"**CL** copy in {b.shape}" if isinstance(b, np.ndarray) else f"**CL** copy OUT {a.shape}") cl.enqueue_copy(CL().cl_queue, a, b, is_blocking=is_blocking) @functools.lru_cache(maxsize=None) class CLProgram: - kernel_cnt = 0 + kernel_cnt = defaultdict(int) def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False): - self.name, self.prg, self.options, self.argdtypes = f"{name}_{CLProgram.kernel_cnt}" if rename else name, prg.replace(f"{name}(", f"{name}_{CLProgram.kernel_cnt}(") if rename else prg, options, argdtypes + self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else ''}" if rename else name + self.prg, self.options, self.argdtypes = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes self.clprogram = cl.Program(CL().cl_ctx, CL().cl_ctx.devices, [self.prg]) if binary else cl.Program(CL().cl_ctx, self.prg) # type: ignore self.clprg = self.clprogram.build(options=list(self.options)).__getattr__(self.name) if self.argdtypes is not None: self.clprg.set_scalar_arg_dtypes(self.argdtypes) - CLProgram.kernel_cnt += 1 + CLProgram.kernel_cnt[name] += 1 def __call__(self, *args, op_estimate=0): CL.kernel_count += 1 if CL.CACHE is not None: @@ -71,7 +72,7 @@ class CLProgram: if DEBUG >= 1: CL.time_sum += 0 if DEBUG <= 1 or CL.CACHE is not None else (e.profile.end - e.profile.start) CL.ops_sum += op_estimate - print(f"**CL** {CL.kernel_count:6d} {self.name:20s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:7.1f}M/{CL.ops_sum/1e9:7.2f}G mem {CL.mem_used/1e9:5.2f} GB " + + print(f"**CL** {CL.kernel_count:6d} {self.name:28s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:7.1f}M/{CL.ops_sum/1e9:7.2f}G mem {CL.mem_used/1e9:5.2f} GB " + ("" if DEBUG <= 1 or CL.CACHE is not None else f"tm {(e.profile.end - e.profile.start)/1e3:9.2f}us/{CL.time_sum/1e6:9.2f}ms ({op_estimate/(e.profile.end - e.profile.start):8.2f} GFLOPS)")) GlobalCounters.global_ops += op_estimate GlobalCounters.global_mem += sum([x.size//4 for x in args[2:] if isinstance(x, cl.Buffer)])