cleanup viz server (#12688)

This commit is contained in:
qazal
2025-10-15 15:58:36 +08:00
committed by GitHub
parent aa81bde150
commit f0268d13f6
3 changed files with 31 additions and 31 deletions

View File

@@ -16,10 +16,11 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non
return sink
# real VIZ=1 pickles these tracked values
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt
traces = [(tracked_keys, tracked_ctxs, uop_fields)]
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt, RewriteTrace
from tinygrad.viz import serve
serve.trace = RewriteTrace(tracked_keys, tracked_ctxs, uop_fields)
from tinygrad.viz.serve import get_metadata, uop_to_json, get_details
def get_viz_list(): return get_metadata(traces)
def get_viz_list(): return get_metadata(serve.trace)
def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]:
  # rewrite_idx indexes into the list returned by get_viz_list(); fail early with a readable message when
  # fewer traces were loaded than the caller expects
  lst = get_viz_list()
  # the message must be an f-string (it was a plain string, so the placeholders were printed verbatim) and
  # must reference the real parameter name; index rewrite_idx needs at least rewrite_idx+1 loaded traces
  assert len(lst) > rewrite_idx, f"only loaded {len(lst)} traces, expecting at least {rewrite_idx+1}"

View File

@@ -1043,6 +1043,9 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += time.perf_counter()-st
return None
@dataclass(frozen=True)
class RewriteTrace:
  # Immutable bundle of the tracked rewrite state; this is the object that gets pickled to rewrites.pkl
  # and handed to the viz server.
  # one tracing key per tracked scope, index-aligned with `rewrites`
  keys:list[TracingKey]
  # for each key, the graph rewrites recorded under it (indexed as rewrites[ctx][idx] by the server)
  rewrites:list[list[TrackedGraphRewrite]]
  # integer id -> tuple of fields (op, dtype, src, arg, ...) used to rebuild a UOp
  uop_fields:dict[int, tuple]
if TRACK_MATCH_STATS or PROFILE:
PatternMatcher = TrackedPatternMatcher # type: ignore
import atexit
@@ -1051,7 +1054,7 @@ if TRACK_MATCH_STATS or PROFILE:
if TRACK_MATCH_STATS >= 2:
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
pickle.dump([(tracked_keys, tracked_ctxs, uop_fields)], f)
pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
if VIZ: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
ret = [0,0,0.0,0.0]

View File

@@ -7,7 +7,7 @@ from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, Generator
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
@@ -24,26 +24,22 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
# VIZ API
# ** Metadata for a track_rewrites scope
# ** list all saved rewrites
ref_map:dict[Any, int] = {}
traces:dict[int, tuple] = {}
def get_metadata(trace_bufs:list[tuple]) -> list[dict]:
def get_metadata(t:RewriteTrace) -> list[dict]:
ret = []
for keys,contexts,uop_fields in trace_bufs:
for k,v in zip(keys, contexts):
traces[i:=len(traces)] = (k, v, uop_fields)
steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc),
"query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)]
ret.append({"name":k.display_name, "steps":steps})
# program spec metadata
if isinstance(k.ret, ProgramSpec):
steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"})
steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"})
for key in k.keys: ref_map[key] = i
for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc),
"query":f"/ctxs?ctx={i}&idx={j}", "depth":s.depth} for j,s in enumerate(v)]
if isinstance(k.ret, ProgramSpec):
steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"})
steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"})
for key in k.keys: ref_map[key] = i
ret.append({"name":k.display_name, "steps":steps})
return ret
# ** Complete rewrite details for a graph_rewrite call
# ** get the complete UOp graphs for one rewrite
class GraphRewriteDetails(TypedDict):
graph: dict # JSON serialized UOp for this rewrite step
@@ -56,7 +52,7 @@ def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in
def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")"
def pystr(u:UOp, i:int) -> str:
try:
return "\n".join(pyrender(u)) if isinstance(traces[i][0].ret, ProgramSpec) else str(u)
return "\n".join(pyrender(u)) if isinstance(trace.keys[i].ret, ProgramSpec) else str(u)
except Exception: return "issue in pyrender"
def uop_to_json(x:UOp) -> dict[int, dict]:
@@ -93,16 +89,16 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
return graph
@functools.cache
def _reconstruct(a:int, i:int):
op, dtype, src, arg, *rest = traces[i][2][a]
arg = type(arg)(_reconstruct(arg.ast, i), arg.metadata) if op is Ops.KERNEL else arg
return UOp(op, dtype, tuple(_reconstruct(s, i) for s in src), arg, *rest)
def _reconstruct(a:int):
op, dtype, src, arg, *rest = trace.uop_fields[a]
arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
def get_details(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink, i)), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num, i)] = u1 = _reconstruct(u1_num, i)
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
@@ -145,7 +141,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
name, info = e.name, None
if (ref:=ref_map.get(name)) is not None:
name = ctxs[ref]["name"]
if isinstance(p:=traces[ref][0].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \
f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}"
elif isinstance(e.name, TracingKey):
@@ -226,7 +222,7 @@ def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary}
def get_render(ctx:list[str], fmt:list[str]):
if not isinstance(prg:=traces[int(ctx[0])][0].ret, ProgramSpec): return
if not isinstance(prg:=trace.keys[int(ctx[0])].ret, ProgramSpec): return
if fmt[0] == "src": return json.dumps({"src":prg.src, "lang":"cpp"}).encode()
lib = (compiler:=Device[prg.device].compiler).compile(prg.src)
with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib)
@@ -256,7 +252,7 @@ class Handler(BaseHTTPRequestHandler):
elif (query:=parse_qs(url.query)):
if url.path == "/render": ret, content_type = get_render(**query), "application/json"
else:
try: return self.stream_json(get_details(traces[i:=int(query["ctx"][0])][1][int(query["idx"][0])], i))
try: return self.stream_json(get_details(trace.rewrites[i:=int(query["ctx"][0])][int(query["idx"][0])], i))
except KeyError: status_code = 404
elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json"
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
@@ -313,7 +309,7 @@ if __name__ == "__main__":
st = time.perf_counter()
print("*** viz is starting")
ctxs = get_metadata(args.kernels)
ctxs = get_metadata(trace:=args.kernels)
profile_ret = get_profile(args.profile)
server = TCPServerWithReuse(('', PORT), Handler)