cleanup viz server (#12688)

This commit is contained in:
qazal
2025-10-15 15:58:36 +08:00
committed by GitHub
parent aa81bde150
commit f0268d13f6
3 changed files with 31 additions and 31 deletions

View File

@@ -16,10 +16,11 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non
return sink return sink
# real VIZ=1 pickles these tracked values # real VIZ=1 pickles these tracked values
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt, RewriteTrace
traces = [(tracked_keys, tracked_ctxs, uop_fields)] from tinygrad.viz import serve
serve.trace = RewriteTrace(tracked_keys, tracked_ctxs, uop_fields)
from tinygrad.viz.serve import get_metadata, uop_to_json, get_details from tinygrad.viz.serve import get_metadata, uop_to_json, get_details
def get_viz_list(): return get_metadata(traces) def get_viz_list(): return get_metadata(serve.trace)
def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]: def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]:
lst = get_viz_list() lst = get_viz_list()
assert len(lst) > rewrite_idx, "only loaded {len(lst)} traces, expecting at least {idx}" assert len(lst) > rewrite_idx, "only loaded {len(lst)} traces, expecting at least {idx}"

View File

@@ -1043,6 +1043,9 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += time.perf_counter()-st match_stats[p][2] += time.perf_counter()-st
return None return None
@dataclass(frozen=True)
class RewriteTrace: keys:list[TracingKey]; rewrites:list[list[TrackedGraphRewrite]]; uop_fields:dict[int, tuple] # noqa: E702
if TRACK_MATCH_STATS or PROFILE: if TRACK_MATCH_STATS or PROFILE:
PatternMatcher = TrackedPatternMatcher # type: ignore PatternMatcher = TrackedPatternMatcher # type: ignore
import atexit import atexit
@@ -1051,7 +1054,7 @@ if TRACK_MATCH_STATS or PROFILE:
if TRACK_MATCH_STATS >= 2: if TRACK_MATCH_STATS >= 2:
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f: with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
pickle.dump([(tracked_keys, tracked_ctxs, uop_fields)], f) pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
if VIZ: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True)) if VIZ: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value): if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
ret = [0,0,0.0,0.0] ret = [0,0,0.0,0.0]

View File

@@ -7,7 +7,7 @@ from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, Generator from typing import Any, TypedDict, Generator
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
from tinygrad.renderer import ProgramSpec from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
@@ -24,26 +24,22 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
# VIZ API # VIZ API
# ** Metadata for a track_rewrites scope # ** list all saved rewrites
ref_map:dict[Any, int] = {} ref_map:dict[Any, int] = {}
traces:dict[int, tuple] = {} def get_metadata(t:RewriteTrace) -> list[dict]:
def get_metadata(trace_bufs:list[tuple]) -> list[dict]:
ret = [] ret = []
for keys,contexts,uop_fields in trace_bufs: for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
for k,v in zip(keys, contexts): steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc),
traces[i:=len(traces)] = (k, v, uop_fields) "query":f"/ctxs?ctx={i}&idx={j}", "depth":s.depth} for j,s in enumerate(v)]
steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc), if isinstance(k.ret, ProgramSpec):
"query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)] steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"})
ret.append({"name":k.display_name, "steps":steps}) steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"})
# program spec metadata for key in k.keys: ref_map[key] = i
if isinstance(k.ret, ProgramSpec): ret.append({"name":k.display_name, "steps":steps})
steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"})
steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"})
for key in k.keys: ref_map[key] = i
return ret return ret
# ** Complete rewrite details for a graph_rewrite call # ** get the complete UOp graphs for one rewrite
class GraphRewriteDetails(TypedDict): class GraphRewriteDetails(TypedDict):
graph: dict # JSON serialized UOp for this rewrite step graph: dict # JSON serialized UOp for this rewrite step
@@ -56,7 +52,7 @@ def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in
def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")" def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")"
def pystr(u:UOp, i:int) -> str: def pystr(u:UOp, i:int) -> str:
try: try:
return "\n".join(pyrender(u)) if isinstance(traces[i][0].ret, ProgramSpec) else str(u) return "\n".join(pyrender(u)) if isinstance(trace.keys[i].ret, ProgramSpec) else str(u)
except Exception: return "issue in pyrender" except Exception: return "issue in pyrender"
def uop_to_json(x:UOp) -> dict[int, dict]: def uop_to_json(x:UOp) -> dict[int, dict]:
@@ -93,16 +89,16 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
return graph return graph
@functools.cache @functools.cache
def _reconstruct(a:int, i:int): def _reconstruct(a:int):
op, dtype, src, arg, *rest = traces[i][2][a] op, dtype, src, arg, *rest = trace.uop_fields[a]
arg = type(arg)(_reconstruct(arg.ast, i), arg.metadata) if op is Ops.KERNEL else arg arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
return UOp(op, dtype, tuple(_reconstruct(s, i) for s in src), arg, *rest) return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
def get_details(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]: def get_details(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink, i)), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None} yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {} replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches): for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num, i)] = u1 = _reconstruct(u1_num, i) replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
try: new_sink = next_sink.substitute(replaces) try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e)) except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc) match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
@@ -145,7 +141,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
name, info = e.name, None name, info = e.name, None
if (ref:=ref_map.get(name)) is not None: if (ref:=ref_map.get(name)) is not None:
name = ctxs[ref]["name"] name = ctxs[ref]["name"]
if isinstance(p:=traces[ref][0].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None: if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \ info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \
f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}" f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}"
elif isinstance(e.name, TracingKey): elif isinstance(e.name, TracingKey):
@@ -226,7 +222,7 @@ def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary} return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary}
def get_render(ctx:list[str], fmt:list[str]): def get_render(ctx:list[str], fmt:list[str]):
if not isinstance(prg:=traces[int(ctx[0])][0].ret, ProgramSpec): return if not isinstance(prg:=trace.keys[int(ctx[0])].ret, ProgramSpec): return
if fmt[0] == "src": return json.dumps({"src":prg.src, "lang":"cpp"}).encode() if fmt[0] == "src": return json.dumps({"src":prg.src, "lang":"cpp"}).encode()
lib = (compiler:=Device[prg.device].compiler).compile(prg.src) lib = (compiler:=Device[prg.device].compiler).compile(prg.src)
with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib) with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib)
@@ -256,7 +252,7 @@ class Handler(BaseHTTPRequestHandler):
elif (query:=parse_qs(url.query)): elif (query:=parse_qs(url.query)):
if url.path == "/render": ret, content_type = get_render(**query), "application/json" if url.path == "/render": ret, content_type = get_render(**query), "application/json"
else: else:
try: return self.stream_json(get_details(traces[i:=int(query["ctx"][0])][1][int(query["idx"][0])], i)) try: return self.stream_json(get_details(trace.rewrites[i:=int(query["ctx"][0])][int(query["idx"][0])], i))
except KeyError: status_code = 404 except KeyError: status_code = 404
elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json" elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json"
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream" elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
@@ -313,7 +309,7 @@ if __name__ == "__main__":
st = time.perf_counter() st = time.perf_counter()
print("*** viz is starting") print("*** viz is starting")
ctxs = get_metadata(args.kernels) ctxs = get_metadata(trace:=args.kernels)
profile_ret = get_profile(args.profile) profile_ret = get_profile(args.profile)
server = TCPServerWithReuse(('', PORT), Handler) server = TCPServerWithReuse(('', PORT), Handler)