diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index cb4af7beb4..a2d387a0e3 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -9,6 +9,27 @@ from tinygrad.device import Device, Buffer from tinygrad.renderer import ProgramSpec, Estimates from tinygrad.codegen import get_program +# **************** Stat **************** + +def update_stats(display_name:str, device:str, estimates:Estimates, var_vals:dict[str, int], et:float|None, buf_count:int, + jit=False, metadata:tuple[Metadata, ...]=(), first_run=False): + GlobalCounters.kernel_count += 1 + GlobalCounters.global_ops += (op_est:=sym_infer(estimates.ops, var_vals)) + GlobalCounters.global_mem += (mem_est:=sym_infer(estimates.mem, var_vals)) + if et is not None: GlobalCounters.time_sum_s += et + if DEBUG >= 2: + lds_est = sym_infer(estimates.lds, var_vals) + header_color = 'magenta' if jit else ('green' if first_run else None) + ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else "" + flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20) + flops_str = f"{flops*1e-9:7.0f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:7.0f} TFLOPS", 'green') + mem_str = f"{membw*1e-9:4.0f}|{ldsbw*1e-9:<6.0f} GB/s" if membw < 1e13 and ldsbw < 1e15 else \ + colored(f"{membw*1e-12:4.0f}|{ldsbw*1e-12:<6.0f} TB/s", 'green') + print(f"{colored(f'*** {device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+ + f" {display_name+' '*(46-ansilen(display_name))} arg {buf_count:2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+ + ("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+ + f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in metadata] if metadata else ''}") + # **************** Runners **************** class Runner: @@ -166,22 +187,7 @@ class ExecItem: cpu_events.append(ProfilePointEvent(self.prg.device, "exec", len(cpu_events), payload)) et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2) if do_update_stats: - GlobalCounters.kernel_count += 1 - GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals)) - GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals)) - if et is not None: GlobalCounters.time_sum_s += et - if DEBUG >= 2: - lds_est = sym_infer(self.prg.estimates.lds, var_vals) - header_color = 'magenta' if jit else ('green' if self.prg.first_run else None) - ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else "" - flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20) - flops_str = f"{flops*1e-9:7.0f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:7.0f} TFLOPS", 'green') - mem_str = f"{membw*1e-9:4.0f}|{ldsbw*1e-9:<6.0f} GB/s" if membw < 1e13 and ldsbw < 1e15 else \ - colored(f"{membw*1e-12:4.0f}|{ldsbw*1e-12:<6.0f} TB/s", 'green') - print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+ - f" {self.prg.display_name+' '*(46-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+ - ("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+ - f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}") + update_stats(self.prg.display_name, self.prg.device, self.prg.estimates, var_vals, et, len(bufs), jit, self.metadata, self.prg.first_run) self.prg.first_run = False return et