mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
viz: add metadata and var_vals tracing (#11753)
* viz: add metadata and var_vals tracing * add test_trace_metadata * set TRACEMETA=1
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest, struct, contextlib, statistics, time, gc
|
||||
from tinygrad import Device, Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import CI, getenv, Context, ProfileRangeEvent, cpu_profile, cpu_events
|
||||
from tinygrad.helpers import CI, getenv, Context, ProfileRangeEvent, cpu_profile, cpu_events, ProfilePointEvent, dedup
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileDeviceEvent, ProfileGraphEvent
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.engine.realize import get_runner
|
||||
@@ -209,5 +209,18 @@ class TestProfiler(unittest.TestCase):
|
||||
for ge in graphs:
|
||||
self.assertEqual(len(ge.ents), len(graphs))
|
||||
|
||||
def test_trace_metadata(self):
|
||||
with Context(TRACEMETA=1):
|
||||
a = Tensor.empty(1)+2
|
||||
b = Tensor.empty(1)+2
|
||||
with helper_collect_profile(TestProfiler.d0) as profile:
|
||||
Tensor.realize(a, b)
|
||||
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
|
||||
exec_points = [e for e in profile if isinstance(e, ProfilePointEvent) and e.name == "exec"]
|
||||
range_events = [e for e in profile if isinstance(e, ProfileRangeEvent)]
|
||||
self.assertEqual(len(exec_points), len(range_events), 2)
|
||||
self.assertEqual(len(dedup(e.key for e in exec_points)), 1)
|
||||
self.assertEqual(len(dedup(e.arg['metadata'] for e in exec_points)), 1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import cast, Generator
|
||||
import time, pprint
|
||||
import time, pprint, decimal
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
@@ -149,6 +149,8 @@ class ExecItem:
|
||||
def run(self, _var_vals:dict[Variable, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
|
||||
var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
|
||||
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
|
||||
if PROFILE: cpu_events.append(ProfilePointEvent(self.prg.device, "exec", decimal.Decimal(time.perf_counter_ns())/1000, self.prg.display_name,
|
||||
{"metadata":self.metadata, "var_vals":var_vals}))
|
||||
et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
|
||||
if do_update_stats:
|
||||
GlobalCounters.kernel_count += 1
|
||||
|
||||
@@ -205,7 +205,7 @@ class ProfileEvent: pass
|
||||
class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfilePointEvent(ProfileEvent): device:str; name:str; ts:decimal.Decimal; key:int; arg:dict=field(default_factory=dict) # noqa: E702
|
||||
class ProfilePointEvent(ProfileEvent): device:str; name:str; ts:decimal.Decimal; key:Any; arg:dict=field(default_factory=dict) # noqa: E702
|
||||
|
||||
cpu_events:list[ProfileEvent] = []
|
||||
@contextlib.contextmanager
|
||||
|
||||
@@ -7,7 +7,7 @@ from http.server import BaseHTTPRequestHandler
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from typing import Any, TypedDict, Generator
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -126,7 +126,9 @@ def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decim
|
||||
def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
|
||||
shapes:list[dict] = []
|
||||
levels:list[int] = []
|
||||
exec_points:dict[str, dict] = {}
|
||||
for st,et,dur,e in events:
|
||||
if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.key] = e.arg
|
||||
if dur == 0: continue
|
||||
# find a free level to put the event
|
||||
depth = next((i for i,level_et in enumerate(levels) if st>=level_et), len(levels))
|
||||
@@ -135,9 +137,9 @@ def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
|
||||
name, cat, info = e.name, None, None
|
||||
if (ref:=ref_map.get(name)) is not None:
|
||||
name = ctxs[ref]["name"]
|
||||
# TODO: support symbolic by capturing var_vals in profile events
|
||||
if isinstance(p:=contexts[0][ref].ret, ProgramSpec) and all(isinstance(es,int) for es in [p.estimates.ops, p.estimates.mem, p.estimates.lds]):
|
||||
info = f"{p.estimates.ops/(t:=dur*1e3):.2f} GFLOPS {p.estimates.mem/t:4.1f}|{p.estimates.lds/t:.1f} GB/s"
|
||||
if isinstance(p:=contexts[0][ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
|
||||
info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \
|
||||
f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}"
|
||||
elif isinstance(e.name, TracingKey):
|
||||
name, cat = e.name.display_name, e.name.cat
|
||||
ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
|
||||
|
||||
Reference in New Issue
Block a user