viz: remove ProgramSpec from trace (#14818)

This commit is contained in:
qazal
2026-02-17 18:04:58 +08:00
committed by GitHub
parent f590564bf7
commit 99a988b9d2
2 changed files with 10 additions and 7 deletions

View File

@@ -156,7 +156,7 @@ pm_to_program = PatternMatcher([
])
@Context(ALLOW_DEVICE_USAGE=0)
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast)), replay=True)
def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
"""
Transform an AST into a ProgramSpec. May trigger BEAM search.

View File

@@ -39,7 +39,6 @@ class HTTPRequestHandler(BaseHTTPRequestHandler):
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import print_uops, range_start, multirange_str
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent
from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
@@ -69,8 +68,8 @@ def get_rewrites(t:RewriteTrace) -> list[dict]:
for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
steps = [create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
trace=k.tb if j==0 else None, depth=s.depth) for j,s in enumerate(v)]
if (prg_idx:=next((j for j,s in enumerate(v) if s.name == "View Program"), None)) is not None:
_, device, lin, src, binary = _reconstruct(trace.rewrites[i][prg_idx].sink).src
if (p:=get_prg_uop(i)) is not None:
_, device, lin, src, binary = p.src
steps.append(create_step("View UOp List", ("/uops", i, len(steps)), lin.src))
steps.append(create_step("View Source", ("/code", i, len(steps)), src.arg))
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (device.arg, binary.arg)))
@@ -162,6 +161,10 @@ def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails,
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
if not ctx.bottom_up: next_sink = new_sink
def get_prg_uop(i:int) -> UOp|None:
s = next((s for s in trace.rewrites[i] if s.name == "View Program"), None)
return _reconstruct(s.sink) if s is not None else None
# encoder helpers
def enum_str(s, cache:dict[str, int]) -> int:
@@ -204,9 +207,9 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
name, fmt, key = e.name, [], None
if (ref:=ref_map.get(name)) is not None:
name = ctxs[ref]["name"]
if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
flops = sym_infer(p.estimates.ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
membw, ldsbw = sym_infer(p.estimates.mem, var_vals)/t, sym_infer(p.estimates.lds, var_vals)/t
if (p:=get_prg_uop(ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
flops = sym_infer((estimates:=p.src[0].arg.estimates).ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
membw, ldsbw = sym_infer(estimates.mem, var_vals)/t, sym_infer(estimates.lds, var_vals)/t
fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS",
(f"{membw*1e-9:.0f} GB/s" if membw < 1e13 else f"{membw*1e-12:.0f} TB/s")+" mem",
(f"{ldsbw*1e-9:.0f} GB/s" if ldsbw < 1e15 else f"{ldsbw*1e-12:.0f} TB/s")+" lds"]