diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 5ba00735eb..2ca4c2e230 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -16,10 +16,11 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non return sink # real VIZ=1 pickles these tracked values -from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt -traces = [(tracked_keys, tracked_ctxs, uop_fields)] +from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt, RewriteTrace +from tinygrad.viz import serve +serve.trace = RewriteTrace(tracked_keys, tracked_ctxs, uop_fields) from tinygrad.viz.serve import get_metadata, uop_to_json, get_details -def get_viz_list(): return get_metadata(traces) +def get_viz_list(): return get_metadata(serve.trace) def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]: lst = get_viz_list() - assert len(lst) > rewrite_idx, "only loaded {len(lst)} traces, expecting at least {idx}" + assert len(lst) > rewrite_idx, f"only loaded {len(lst)} traces, expecting at least {rewrite_idx}" diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index b2df8889c4..ca3ed70f22 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1043,6 +1043,9 @@ class TrackedPatternMatcher(PatternMatcher): match_stats[p][2] += time.perf_counter()-st return None +@dataclass(frozen=True) +class RewriteTrace: keys:list[TracingKey]; rewrites:list[list[TrackedGraphRewrite]]; uop_fields:dict[int, tuple] # noqa: E702 + if TRACK_MATCH_STATS or PROFILE: PatternMatcher = TrackedPatternMatcher # type: ignore import atexit @@ -1051,7 +1054,7 @@ if TRACK_MATCH_STATS or PROFILE: if TRACK_MATCH_STATS >= 2: with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f: print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") - pickle.dump([(tracked_keys, tracked_ctxs, uop_fields)], f) + pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f) if VIZ: return launch_viz("VIZ", 
temp("rewrites.pkl", append_user=True)) if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value): ret = [0,0,0.0,0.0] diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 7a4b825319..ed42c67429 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -7,7 +7,7 @@ from http.server import BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from typing import Any, TypedDict, Generator from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp -from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender +from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device from tinygrad.renderer import ProgramSpec from tinygrad.dtype import dtypes @@ -24,26 +24,22 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", # VIZ API -# ** Metadata for a track_rewrites scope +# ** list all saved rewrites ref_map:dict[Any, int] = {} -traces:dict[int, tuple] = {} -def get_metadata(trace_bufs:list[tuple]) -> list[dict]: +def get_metadata(t:RewriteTrace) -> list[dict]: ret = [] - for keys,contexts,uop_fields in trace_bufs: - for k,v in zip(keys, contexts): - traces[i:=len(traces)] = (k, v, uop_fields) - steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc), - "query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)] - ret.append({"name":k.display_name, "steps":steps}) - # program spec metadata - if isinstance(k.ret, ProgramSpec): - steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"}) - steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"}) - for key in k.keys: ref_map[key] = i + for i,(k,v) in 
enumerate(zip(t.keys, t.rewrites)): + steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc), + "query":f"/ctxs?ctx={i}&idx={j}", "depth":s.depth} for j,s in enumerate(v)] + if isinstance(k.ret, ProgramSpec): + steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"}) + steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"}) + for key in k.keys: ref_map[key] = i + ret.append({"name":k.display_name, "steps":steps}) return ret -# ** Complete rewrite details for a graph_rewrite call +# ** get the complete UOp graphs for one rewrite class GraphRewriteDetails(TypedDict): graph: dict # JSON serialized UOp for this rewrite step @@ -56,7 +52,7 @@ def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")" def pystr(u:UOp, i:int) -> str: try: - return "\n".join(pyrender(u)) if isinstance(traces[i][0].ret, ProgramSpec) else str(u) + return "\n".join(pyrender(u)) if isinstance(trace.keys[i].ret, ProgramSpec) else str(u) except Exception: return "issue in pyrender" def uop_to_json(x:UOp) -> dict[int, dict]: @@ -93,16 +89,16 @@ def uop_to_json(x:UOp) -> dict[int, dict]: return graph @functools.cache -def _reconstruct(a:int, i:int): - op, dtype, src, arg, *rest = traces[i][2][a] - arg = type(arg)(_reconstruct(arg.ast, i), arg.metadata) if op is Ops.KERNEL else arg - return UOp(op, dtype, tuple(_reconstruct(s, i) for s in src), arg, *rest) +def _reconstruct(a:int): + op, dtype, src, arg, *rest = trace.uop_fields[a] + arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg + return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest) def get_details(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]: - yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink, i)), "uop":pystr(next_sink,i), "changed_nodes":None, 
"diff":None, "upat":None} + yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None} replaces: dict[UOp, UOp] = {} for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches): - replaces[u0:=_reconstruct(u0_num, i)] = u1 = _reconstruct(u1_num, i) + replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num) try: new_sink = next_sink.substitute(replaces) except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e)) match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc) @@ -145,7 +141,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts: name, info = e.name, None if (ref:=ref_map.get(name)) is not None: name = ctxs[ref]["name"] - if isinstance(p:=traces[ref][0].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None: + if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None: info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \ f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}" elif isinstance(e.name, TracingKey): @@ -226,7 +222,7 @@ def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict: return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary} def get_render(ctx:list[str], fmt:list[str]): - if not isinstance(prg:=traces[int(ctx[0])][0].ret, ProgramSpec): return + if not isinstance(prg:=trace.keys[int(ctx[0])].ret, ProgramSpec): return if fmt[0] == "src": return json.dumps({"src":prg.src, "lang":"cpp"}).encode() lib = (compiler:=Device[prg.device].compiler).compile(prg.src) with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib) @@ -256,7 +252,7 @@ class Handler(BaseHTTPRequestHandler): elif (query:=parse_qs(url.query)): if url.path == "/render": ret, content_type = get_render(**query), "application/json" else: - try: return 
self.stream_json(get_details(traces[i:=int(query["ctx"][0])][1][int(query["idx"][0])], i)) - except KeyError: status_code = 404 + try: return self.stream_json(get_details(trace.rewrites[i:=int(query["ctx"][0])][int(query["idx"][0])], i)) + except (KeyError, IndexError): status_code = 404 elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json" elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream" @@ -313,7 +309,7 @@ if __name__ == "__main__": st = time.perf_counter() print("*** viz is starting") - ctxs = get_metadata(args.kernels) + ctxs = get_metadata(trace:=args.kernels) profile_ret = get_profile(args.profile) server = TCPServerWithReuse(('', PORT), Handler)