From a4cb161bd430edf30a1f6345bbe726bce938505b Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 10 Feb 2023 21:51:53 -0600 Subject: [PATCH] log_kernel --- tinygrad/ops.py | 5 +++++ tinygrad/runtime/cuda.py | 4 +--- tinygrad/runtime/opencl.py | 4 +--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 51727e1a76..cfff00e4e6 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -86,6 +86,11 @@ class GlobalCounters: cache : ClassVar[Optional[list]] = None @staticmethod def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0,0,None + @staticmethod + def log_kernel(op_estimate:int, mem_estimate:int): + GlobalCounters.kernel_count += 1 + GlobalCounters.global_ops += op_estimate + GlobalCounters.global_mem += mem_estimate # assumes you are using ShapeTracker # used in GPUBuffer and LLVMBuffer diff --git a/tinygrad/runtime/cuda.py b/tinygrad/runtime/cuda.py index 5c3c37d0f5..9235de00a0 100644 --- a/tinygrad/runtime/cuda.py +++ b/tinygrad/runtime/cuda.py @@ -26,6 +26,4 @@ class CLProgram: global_size = [x//y for x,y in zip(global_size, local_size)] if DEBUG >= 2: print("CUDA launch", global_size, local_size) self.prg(*args, block=tuple(local_size), grid=tuple(global_size)) - GlobalCounters.kernel_count += 1 - GlobalCounters.global_ops += self.op_estimate - GlobalCounters.global_mem += self.mem_estimate \ No newline at end of file + GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate) \ No newline at end of file diff --git a/tinygrad/runtime/opencl.py b/tinygrad/runtime/opencl.py index d42f47f7a6..6b7aef44ad 100644 --- a/tinygrad/runtime/opencl.py +++ b/tinygrad/runtime/opencl.py @@ -91,7 +91,5 @@ class CLProgram: if DEBUG >= 1: print(f"**CL** {GlobalCounters.kernel_count:6d} {self.name:28s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {CL.mem_used/1e9:5.2f} GB " + (str() if DEBUG <= 1 or CL.CACHE is not None else f"tm {et/1e3:9.2f}us/{GlobalCounters.time_sum/1e6:9.2f}ms ({self.op_estimate/et:8.2f} GFLOPS)")) - GlobalCounters.kernel_count += 1 - GlobalCounters.global_ops += self.op_estimate - GlobalCounters.global_mem += self.mem_estimate + GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate) return e if CL.CACHE is None else None