gpu: rename kernels

This commit is contained in:
George Hotz
2023-01-09 19:32:22 -08:00
parent 1e1abb450e
commit 4356683081

View File

@@ -43,21 +43,22 @@ class CL:
@staticmethod
def enqueue_copy(a, b, is_blocking=False):
if CL.CACHE is not None:
- assert False, "can't copy while caching"
+ assert False, f"can't copy {a} -> {b} while caching"
if DEBUG >= 1:
print(f"**CL** copy in {b.shape}" if isinstance(b, np.ndarray) else f"**CL** copy OUT {a.shape}")
cl.enqueue_copy(CL().cl_queue, a, b, is_blocking=is_blocking)
@functools.lru_cache(maxsize=None)
class CLProgram:
- kernel_cnt = 0
+ kernel_cnt = defaultdict(int)
def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False):
- self.name, self.prg, self.options, self.argdtypes = f"{name}_{CLProgram.kernel_cnt}" if rename else name, prg.replace(f"{name}(", f"{name}_{CLProgram.kernel_cnt}(") if rename else prg, options, argdtypes
+ self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else ''}" if rename else name
+ self.prg, self.options, self.argdtypes = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes
self.clprogram = cl.Program(CL().cl_ctx, CL().cl_ctx.devices, [self.prg]) if binary else cl.Program(CL().cl_ctx, self.prg) # type: ignore
self.clprg = self.clprogram.build(options=list(self.options)).__getattr__(self.name)
if self.argdtypes is not None:
self.clprg.set_scalar_arg_dtypes(self.argdtypes)
- CLProgram.kernel_cnt += 1
+ CLProgram.kernel_cnt[name] += 1
def __call__(self, *args, op_estimate=0):
CL.kernel_count += 1
if CL.CACHE is not None:
@@ -71,7 +72,7 @@ class CLProgram:
if DEBUG >= 1:
CL.time_sum += 0 if DEBUG <= 1 or CL.CACHE is not None else (e.profile.end - e.profile.start)
CL.ops_sum += op_estimate
- print(f"**CL** {CL.kernel_count:6d} {self.name:20s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:7.1f}M/{CL.ops_sum/1e9:7.2f}G mem {CL.mem_used/1e9:5.2f} GB " +
+ print(f"**CL** {CL.kernel_count:6d} {self.name:28s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:7.1f}M/{CL.ops_sum/1e9:7.2f}G mem {CL.mem_used/1e9:5.2f} GB " +
("" if DEBUG <= 1 or CL.CACHE is not None else f"tm {(e.profile.end - e.profile.start)/1e3:9.2f}us/{CL.time_sum/1e6:9.2f}ms ({op_estimate/(e.profile.end - e.profile.start):8.2f} GFLOPS)"))
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += sum([x.size//4 for x in args[2:] if isinstance(x, cl.Buffer)])