mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
split update stat from execitem (#15654)
This commit is contained in:
@@ -9,6 +9,27 @@ from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import ProgramSpec, Estimates
|
||||
from tinygrad.codegen import get_program
|
||||
|
||||
# **************** Stat ****************
|
||||
|
||||
def update_stats(display_name:str, device:str, estimates:Estimates, var_vals:dict[str, int], et:float|None, buf_count:int,
|
||||
jit=False, metadata:tuple[Metadata, ...]=(), first_run=False):
|
||||
GlobalCounters.kernel_count += 1
|
||||
GlobalCounters.global_ops += (op_est:=sym_infer(estimates.ops, var_vals))
|
||||
GlobalCounters.global_mem += (mem_est:=sym_infer(estimates.mem, var_vals))
|
||||
if et is not None: GlobalCounters.time_sum_s += et
|
||||
if DEBUG >= 2:
|
||||
lds_est = sym_infer(estimates.lds, var_vals)
|
||||
header_color = 'magenta' if jit else ('green' if first_run else None)
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
|
||||
flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)
|
||||
flops_str = f"{flops*1e-9:7.0f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:7.0f} TFLOPS", 'green')
|
||||
mem_str = f"{membw*1e-9:4.0f}|{ldsbw*1e-9:<6.0f} GB/s" if membw < 1e13 and ldsbw < 1e15 else \
|
||||
colored(f"{membw*1e-12:4.0f}|{ldsbw*1e-12:<6.0f} TB/s", 'green')
|
||||
print(f"{colored(f'*** {device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+
|
||||
f" {display_name+' '*(46-ansilen(display_name))} arg {buf_count:2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+
|
||||
("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
|
||||
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in metadata] if metadata else ''}")
|
||||
|
||||
# **************** Runners ****************
|
||||
|
||||
class Runner:
|
||||
@@ -166,22 +187,7 @@ class ExecItem:
|
||||
cpu_events.append(ProfilePointEvent(self.prg.device, "exec", len(cpu_events), payload))
|
||||
et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
|
||||
if do_update_stats:
|
||||
GlobalCounters.kernel_count += 1
|
||||
GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
|
||||
GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
|
||||
if et is not None: GlobalCounters.time_sum_s += et
|
||||
if DEBUG >= 2:
|
||||
lds_est = sym_infer(self.prg.estimates.lds, var_vals)
|
||||
header_color = 'magenta' if jit else ('green' if self.prg.first_run else None)
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
|
||||
flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)
|
||||
flops_str = f"{flops*1e-9:7.0f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:7.0f} TFLOPS", 'green')
|
||||
mem_str = f"{membw*1e-9:4.0f}|{ldsbw*1e-9:<6.0f} GB/s" if membw < 1e13 and ldsbw < 1e15 else \
|
||||
colored(f"{membw*1e-12:4.0f}|{ldsbw*1e-12:<6.0f} TB/s", 'green')
|
||||
print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+
|
||||
f" {self.prg.display_name+' '*(46-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+
|
||||
("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
|
||||
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}")
|
||||
update_stats(self.prg.display_name, self.prg.device, self.prg.estimates, var_vals, et, len(bufs), jit, self.metadata, self.prg.first_run)
|
||||
self.prg.first_run = False
|
||||
return et
|
||||
|
||||
|
||||
Reference in New Issue
Block a user