mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
cleanup viz server (#12688)
This commit is contained in:
@@ -16,10 +16,11 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non
|
|||||||
return sink
|
return sink
|
||||||
|
|
||||||
# real VIZ=1 pickles these tracked values
|
# real VIZ=1 pickles these tracked values
|
||||||
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt
|
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, _name_cnt, RewriteTrace
|
||||||
traces = [(tracked_keys, tracked_ctxs, uop_fields)]
|
from tinygrad.viz import serve
|
||||||
|
serve.trace = RewriteTrace(tracked_keys, tracked_ctxs, uop_fields)
|
||||||
from tinygrad.viz.serve import get_metadata, uop_to_json, get_details
|
from tinygrad.viz.serve import get_metadata, uop_to_json, get_details
|
||||||
def get_viz_list(): return get_metadata(traces)
|
def get_viz_list(): return get_metadata(serve.trace)
|
||||||
def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]:
|
def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]:
|
||||||
lst = get_viz_list()
|
lst = get_viz_list()
|
||||||
assert len(lst) > rewrite_idx, "only loaded {len(lst)} traces, expecting at least {idx}"
|
assert len(lst) > rewrite_idx, "only loaded {len(lst)} traces, expecting at least {idx}"
|
||||||
|
|||||||
@@ -1043,6 +1043,9 @@ class TrackedPatternMatcher(PatternMatcher):
|
|||||||
match_stats[p][2] += time.perf_counter()-st
|
match_stats[p][2] += time.perf_counter()-st
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RewriteTrace: keys:list[TracingKey]; rewrites:list[list[TrackedGraphRewrite]]; uop_fields:dict[int, tuple] # noqa: E702
|
||||||
|
|
||||||
if TRACK_MATCH_STATS or PROFILE:
|
if TRACK_MATCH_STATS or PROFILE:
|
||||||
PatternMatcher = TrackedPatternMatcher # type: ignore
|
PatternMatcher = TrackedPatternMatcher # type: ignore
|
||||||
import atexit
|
import atexit
|
||||||
@@ -1051,7 +1054,7 @@ if TRACK_MATCH_STATS or PROFILE:
|
|||||||
if TRACK_MATCH_STATS >= 2:
|
if TRACK_MATCH_STATS >= 2:
|
||||||
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
|
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
|
||||||
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
|
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
|
||||||
pickle.dump([(tracked_keys, tracked_ctxs, uop_fields)], f)
|
pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
|
||||||
if VIZ: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
|
if VIZ: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
|
||||||
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
|
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
|
||||||
ret = [0,0,0.0,0.0]
|
ret = [0,0,0.0,0.0]
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from http.server import BaseHTTPRequestHandler
|
|||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
from typing import Any, TypedDict, Generator
|
from typing import Any, TypedDict, Generator
|
||||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
|
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
|
||||||
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender
|
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender
|
||||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
|
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
|
||||||
from tinygrad.renderer import ProgramSpec
|
from tinygrad.renderer import ProgramSpec
|
||||||
from tinygrad.dtype import dtypes
|
from tinygrad.dtype import dtypes
|
||||||
@@ -24,26 +24,22 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||||||
|
|
||||||
# VIZ API
|
# VIZ API
|
||||||
|
|
||||||
# ** Metadata for a track_rewrites scope
|
# ** list all saved rewrites
|
||||||
|
|
||||||
ref_map:dict[Any, int] = {}
|
ref_map:dict[Any, int] = {}
|
||||||
traces:dict[int, tuple] = {}
|
def get_metadata(t:RewriteTrace) -> list[dict]:
|
||||||
def get_metadata(trace_bufs:list[tuple]) -> list[dict]:
|
|
||||||
ret = []
|
ret = []
|
||||||
for keys,contexts,uop_fields in trace_bufs:
|
for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
|
||||||
for k,v in zip(keys, contexts):
|
steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc),
|
||||||
traces[i:=len(traces)] = (k, v, uop_fields)
|
"query":f"/ctxs?ctx={i}&idx={j}", "depth":s.depth} for j,s in enumerate(v)]
|
||||||
steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc),
|
if isinstance(k.ret, ProgramSpec):
|
||||||
"query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)]
|
steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"})
|
||||||
ret.append({"name":k.display_name, "steps":steps})
|
steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"})
|
||||||
# program spec metadata
|
for key in k.keys: ref_map[key] = i
|
||||||
if isinstance(k.ret, ProgramSpec):
|
ret.append({"name":k.display_name, "steps":steps})
|
||||||
steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"})
|
|
||||||
steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"})
|
|
||||||
for key in k.keys: ref_map[key] = i
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# ** Complete rewrite details for a graph_rewrite call
|
# ** get the complete UOp graphs for one rewrite
|
||||||
|
|
||||||
class GraphRewriteDetails(TypedDict):
|
class GraphRewriteDetails(TypedDict):
|
||||||
graph: dict # JSON serialized UOp for this rewrite step
|
graph: dict # JSON serialized UOp for this rewrite step
|
||||||
@@ -56,7 +52,7 @@ def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in
|
|||||||
def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")"
|
def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")"
|
||||||
def pystr(u:UOp, i:int) -> str:
|
def pystr(u:UOp, i:int) -> str:
|
||||||
try:
|
try:
|
||||||
return "\n".join(pyrender(u)) if isinstance(traces[i][0].ret, ProgramSpec) else str(u)
|
return "\n".join(pyrender(u)) if isinstance(trace.keys[i].ret, ProgramSpec) else str(u)
|
||||||
except Exception: return "issue in pyrender"
|
except Exception: return "issue in pyrender"
|
||||||
|
|
||||||
def uop_to_json(x:UOp) -> dict[int, dict]:
|
def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||||
@@ -93,16 +89,16 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||||||
return graph
|
return graph
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def _reconstruct(a:int, i:int):
|
def _reconstruct(a:int):
|
||||||
op, dtype, src, arg, *rest = traces[i][2][a]
|
op, dtype, src, arg, *rest = trace.uop_fields[a]
|
||||||
arg = type(arg)(_reconstruct(arg.ast, i), arg.metadata) if op is Ops.KERNEL else arg
|
arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
|
||||||
return UOp(op, dtype, tuple(_reconstruct(s, i) for s in src), arg, *rest)
|
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
|
||||||
|
|
||||||
def get_details(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
|
def get_details(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
|
||||||
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink, i)), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
|
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
|
||||||
replaces: dict[UOp, UOp] = {}
|
replaces: dict[UOp, UOp] = {}
|
||||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
||||||
replaces[u0:=_reconstruct(u0_num, i)] = u1 = _reconstruct(u1_num, i)
|
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
|
||||||
try: new_sink = next_sink.substitute(replaces)
|
try: new_sink = next_sink.substitute(replaces)
|
||||||
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||||
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
|
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
|
||||||
@@ -145,7 +141,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
|
|||||||
name, info = e.name, None
|
name, info = e.name, None
|
||||||
if (ref:=ref_map.get(name)) is not None:
|
if (ref:=ref_map.get(name)) is not None:
|
||||||
name = ctxs[ref]["name"]
|
name = ctxs[ref]["name"]
|
||||||
if isinstance(p:=traces[ref][0].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
|
if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
|
||||||
info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \
|
info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \
|
||||||
f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}"
|
f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}"
|
||||||
elif isinstance(e.name, TracingKey):
|
elif isinstance(e.name, TracingKey):
|
||||||
@@ -226,7 +222,7 @@ def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
|
|||||||
return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary}
|
return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary}
|
||||||
|
|
||||||
def get_render(ctx:list[str], fmt:list[str]):
|
def get_render(ctx:list[str], fmt:list[str]):
|
||||||
if not isinstance(prg:=traces[int(ctx[0])][0].ret, ProgramSpec): return
|
if not isinstance(prg:=trace.keys[int(ctx[0])].ret, ProgramSpec): return
|
||||||
if fmt[0] == "src": return json.dumps({"src":prg.src, "lang":"cpp"}).encode()
|
if fmt[0] == "src": return json.dumps({"src":prg.src, "lang":"cpp"}).encode()
|
||||||
lib = (compiler:=Device[prg.device].compiler).compile(prg.src)
|
lib = (compiler:=Device[prg.device].compiler).compile(prg.src)
|
||||||
with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib)
|
with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib)
|
||||||
@@ -256,7 +252,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
elif (query:=parse_qs(url.query)):
|
elif (query:=parse_qs(url.query)):
|
||||||
if url.path == "/render": ret, content_type = get_render(**query), "application/json"
|
if url.path == "/render": ret, content_type = get_render(**query), "application/json"
|
||||||
else:
|
else:
|
||||||
try: return self.stream_json(get_details(traces[i:=int(query["ctx"][0])][1][int(query["idx"][0])], i))
|
try: return self.stream_json(get_details(trace.rewrites[i:=int(query["ctx"][0])][int(query["idx"][0])], i))
|
||||||
except KeyError: status_code = 404
|
except KeyError: status_code = 404
|
||||||
elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json"
|
elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json"
|
||||||
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
|
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
|
||||||
@@ -313,7 +309,7 @@ if __name__ == "__main__":
|
|||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
print("*** viz is starting")
|
print("*** viz is starting")
|
||||||
|
|
||||||
ctxs = get_metadata(args.kernels)
|
ctxs = get_metadata(trace:=args.kernels)
|
||||||
profile_ret = get_profile(args.profile)
|
profile_ret = get_profile(args.profile)
|
||||||
|
|
||||||
server = TCPServerWithReuse(('', PORT), Handler)
|
server = TCPServerWithReuse(('', PORT), Handler)
|
||||||
|
|||||||
Reference in New Issue
Block a user