From 99a988b9d2157ab110fd504bb7dfa7093823450e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 17 Feb 2026 18:04:58 +0800 Subject: [PATCH] viz: remove ProgramSpec from trace (#14818) --- tinygrad/codegen/__init__.py | 2 +- tinygrad/viz/serve.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index a5afb844d9..07625a0223 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -156,7 +156,7 @@ pm_to_program = PatternMatcher([ ]) @Context(ALLOW_DEVICE_USAGE=0) -@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True) +@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast)), replay=True) def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec: """ Transform an AST into a ProgramSpec. May trigger BEAM search. diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 05988552bc..0e6f5eb21a 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -39,7 +39,6 @@ class HTTPRequestHandler(BaseHTTPRequestHandler): from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender from tinygrad.uop.ops import print_uops, range_start, multirange_str from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent -from tinygrad.renderer import ProgramSpec from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", @@ -69,8 +68,8 @@ def get_rewrites(t:RewriteTrace) -> list[dict]: for i,(k,v) in enumerate(zip(t.keys, t.rewrites)): steps = [create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc), trace=k.tb if j==0 else None, depth=s.depth) for j,s in enumerate(v)] - if (prg_idx:=next((j for j,s in enumerate(v) if s.name == "View Program"), None)) is not None: - _, device, lin, src, binary = _reconstruct(trace.rewrites[i][prg_idx].sink).src + if (p:=get_prg_uop(i)) is not None: + _, device, lin, src, binary = p.src steps.append(create_step("View UOp List", ("/uops", i, len(steps)), lin.src)) steps.append(create_step("View Source", ("/code", i, len(steps)), src.arg)) steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (device.arg, binary.arg))) @@ -162,6 +161,10 @@ def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, "diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)} if not ctx.bottom_up: next_sink = new_sink +def get_prg_uop(i:int) -> UOp|None: + s = next((s for s in trace.rewrites[i] if s.name == "View Program"), None) + return _reconstruct(s.sink) if s is not None else None + # encoder helpers def enum_str(s, cache:dict[str, int]) -> int: @@ -204,9 +207,9 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts: name, fmt, key = e.name, [], None if (ref:=ref_map.get(name)) is not None: name = ctxs[ref]["name"] - if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None: - flops = sym_infer(p.estimates.ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6) - membw, ldsbw = sym_infer(p.estimates.mem, var_vals)/t, sym_infer(p.estimates.lds, var_vals)/t + if (p:=get_prg_uop(ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None: + flops = sym_infer((estimates:=p.src[0].arg.estimates).ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6) + membw, ldsbw = sym_infer(estimates.mem, var_vals)/t, sym_infer(estimates.lds, var_vals)/t fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS", (f"{membw*1e-9:.0f} GB/s" if membw < 1e13 else f"{membw*1e-12:.0f} TB/s")+" mem", (f"{ldsbw*1e-9:.0f} GB/s" if ldsbw < 1e15 else f"{ldsbw*1e-12:.0f} TB/s")+" lds"]