viz: add metadata and var_vals tracing (#11753)

* viz: add metadata and var_vals tracing

* add test_trace_metadata

* set TRACEMETA=1
This commit is contained in:
qazal
2025-08-20 18:39:51 +03:00
committed by GitHub
parent 6589c9e643
commit de4cb722a4
4 changed files with 25 additions and 8 deletions

View File

@@ -1,6 +1,6 @@
import unittest, struct, contextlib, statistics, time, gc
from tinygrad import Device, Tensor, dtypes, TinyJit
from tinygrad.helpers import CI, getenv, Context, ProfileRangeEvent, cpu_profile, cpu_events
from tinygrad.helpers import CI, getenv, Context, ProfileRangeEvent, cpu_profile, cpu_events, ProfilePointEvent, dedup
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileDeviceEvent, ProfileGraphEvent
from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.engine.realize import get_runner
@@ -209,5 +209,18 @@ class TestProfiler(unittest.TestCase):
for ge in graphs:
self.assertEqual(len(ge.ents), len(graphs))
def test_trace_metadata(self):
with Context(TRACEMETA=1):
a = Tensor.empty(1)+2
b = Tensor.empty(1)+2
with helper_collect_profile(TestProfiler.d0) as profile:
Tensor.realize(a, b)
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
exec_points = [e for e in profile if isinstance(e, ProfilePointEvent) and e.name == "exec"]
range_events = [e for e in profile if isinstance(e, ProfileRangeEvent)]
self.assertEqual(len(exec_points), len(range_events), 2)
self.assertEqual(len(dedup(e.key for e in exec_points)), 1)
self.assertEqual(len(dedup(e.arg['metadata'] for e in exec_points)), 1)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,8 +1,8 @@
from typing import cast, Generator
import time, pprint
import time, pprint, decimal
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
@@ -149,6 +149,8 @@ class ExecItem:
def run(self, _var_vals:dict[Variable, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
if PROFILE: cpu_events.append(ProfilePointEvent(self.prg.device, "exec", decimal.Decimal(time.perf_counter_ns())/1000, self.prg.display_name,
{"metadata":self.metadata, "var_vals":var_vals}))
et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
if do_update_stats:
GlobalCounters.kernel_count += 1

View File

@@ -205,7 +205,7 @@ class ProfileEvent: pass
class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
@dataclass(frozen=True)
class ProfilePointEvent(ProfileEvent): device:str; name:str; ts:decimal.Decimal; key:int; arg:dict=field(default_factory=dict) # noqa: E702
class ProfilePointEvent(ProfileEvent): device:str; name:str; ts:decimal.Decimal; key:Any; arg:dict=field(default_factory=dict) # noqa: E702
cpu_events:list[ProfileEvent] = []
@contextlib.contextmanager

View File

@@ -7,7 +7,7 @@ from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, Generator
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
@@ -126,7 +126,9 @@ def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decim
def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
shapes:list[dict] = []
levels:list[int] = []
exec_points:dict[str, dict] = {}
for st,et,dur,e in events:
if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.key] = e.arg
if dur == 0: continue
# find a free level to put the event
depth = next((i for i,level_et in enumerate(levels) if st>=level_et), len(levels))
@@ -135,9 +137,9 @@ def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
name, cat, info = e.name, None, None
if (ref:=ref_map.get(name)) is not None:
name = ctxs[ref]["name"]
# TODO: support symbolic by capturing var_vals in profile events
if isinstance(p:=contexts[0][ref].ret, ProgramSpec) and all(isinstance(es,int) for es in [p.estimates.ops, p.estimates.mem, p.estimates.lds]):
info = f"{p.estimates.ops/(t:=dur*1e3):.2f} GFLOPS {p.estimates.mem/t:4.1f}|{p.estimates.lds/t:.1f} GB/s"
if isinstance(p:=contexts[0][ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \
f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}"
elif isinstance(e.name, TracingKey):
name, cat = e.name.display_name, e.name.cat
ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)