From 6c69fec1ef9f4409c7a3b2f708fd4d268a48e63d Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:49:40 +0800 Subject: [PATCH] viz more info for rewrite location (#6729) --- tinygrad/ops.py | 4 ++-- viz/index.html | 9 ++++++--- viz/serve.py | 32 +++++++++++++++++++++++--------- viz/test_viz.py | 6 +++--- 4 files changed, 34 insertions(+), 17 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index bba1884c1f..045a307ac8 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -499,7 +499,7 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0) match_stats:Dict[UPat, List[Union[int, float]]] = dict() @dataclass(frozen=True) class TrackedRewriteContext: - loc: str # location that called graph_rewrite + loc: Tuple[str, int] # location that called graph_rewrite sink: UOp # the sink passed into the rewrite kernel_name: Optional[str] = None # the name of the kernel being rewritten rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat) @@ -566,7 +566,7 @@ class RewriteContext: return found def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: if TRACK_MATCH_STATS >= 2: - contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink, _CURRENT_KERNEL.get())) + contexts.append(TrackedRewriteContext(((f:=sys._getframe(1)).f_code.co_filename, f.f_lineno), sink, _CURRENT_KERNEL.get())) return RewriteContext(pm, ctx).rewrite(sink) # ***** uop type spec ***** diff --git a/viz/index.html b/viz/index.html index 2400f99c3b..a118ecc026 100644 --- a/viz/index.html +++ b/viz/index.html @@ -246,7 +246,7 @@ const p = Object.assign(document.createElement("p"), {id: `kernel-${k.name}`, innerText: k.name}) kernelUl.appendChild(p) k.ctxs.forEach((u, j) => { - const rwUl = Object.assign(document.createElement("ul"), { innerText: u, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" }) + const rwUl = Object.assign(document.createElement("ul"), { innerText: `${u.filename} - ${u.match_count}`, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" }) rwUl.style.display = i === currentKernel && expandKernel ? "block" : "none"; rwUl.onclick = (e) => { e.stopPropagation(); @@ -284,7 +284,10 @@ // ***** RHS metadata const metadata = document.querySelector(".container.metadata"); metadata.innerHTML = ""; - metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc })); + metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc.filename })); + const pre = Object.assign(document.createElement("pre"), { innerHTML: `${DOMPurify.sanitize(ret[0].loc.code)}`, className: "wrap code-block language-python" }); + metadata.appendChild(pre); + hljs.highlightElement(pre); // ** resizer metadata.appendChild(Object.assign(document.createElement("div"), { id: "resize-handle" })); const resizeHandle = document.getElementById("resize-handle"); @@ -361,7 +364,7 @@ }); }) } else { - metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${ret[0].loc}.` })); + metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${ret[0].loc.filename}.` })); } } document.addEventListener("keydown", async function(event) { diff --git a/viz/serve.py b/viz/serve.py index 4cb6f26ed8..848e3f061f 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -1,13 +1,13 @@ #!/usr/bin/env python3 from __future__ import annotations -from typing import Dict, List, Tuple -import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib +from typing import Dict, List, Optional, Tuple +import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, re from dataclasses import dataclass, asdict from urllib.parse import parse_qs, urlparse from http.server import HTTPServer, BaseHTTPRequestHandler from tinygrad import Device from tinygrad.helpers import Context, getenv, to_function_name -from tinygrad.ops import TrackedRewriteContext, UOp, UOps +from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines from tinygrad.engine.graph import uops_colors, word_wrap from tinygrad.engine.realize import get_runner from tinygrad.engine.schedule import full_ast_rewrite @@ -17,9 +17,23 @@ from tinygrad.engine.schedule import full_ast_rewrite # NOTE: UPats in ops.py are spec def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"] +@dataclass(frozen=True) +class RewriteLocation: + filename: str + code: str + matcher_name: Optional[str] + match_count: int + @staticmethod + def from_ctx(ctx:TrackedRewriteContext) -> RewriteLocation: + fp, lineno = ctx.loc + p = r"graph_rewrite\([^,]+,\s*([^>]+)\)" + match = re.search(p, code:=lines(fp)[lineno-1].strip()) + return RewriteLocation(f"{fp.split('/')[-1]}:{lineno}", code, match.group(1) if match is not None else None, len(graph_rewrites(ctx))) + def to_json(self): return asdict(self) + @dataclass(frozen=True) class UOpRet: - loc: str + loc: RewriteLocation graphs: List[UOp] # snapshot of the entire AST after each rewrite diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite extra: List[List[str]] # these become code blocks in the UI @@ -42,8 +56,8 @@ class UOpRet: diffs.append((pattern.printable(), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))) uops.append(new_sink) extra.append([str(new_sink)]) - return UOpRet(ctx.loc, uops, diffs, extra, additions) - def to_json(self) -> Dict: return {**asdict(self), "graphs": list(map(uop_to_json, self.graphs))} + return UOpRet(RewriteLocation.from_ctx(ctx), uops, diffs, extra, additions) + def to_json(self) -> Dict: return {**asdict(self), "loc":self.loc.to_json(), "graphs": list(map(uop_to_json, self.graphs))} def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: assert isinstance(x, UOp) @@ -71,16 +85,16 @@ def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp: class KernelRet: name: str code: str - ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext] + ctxs: Dict[Tuple[Tuple, bytes], TrackedRewriteContext] def to_json(self) -> Dict: - return {"name":self.name, "code":self.code, "ctxs":[f"{x.loc} - {len(graph_rewrites(x))}" for x in self.ctxs.values()]} + return {"name":self.name, "code":self.code, "ctxs":[RewriteLocation.from_ctx(x).to_json() for x in self.ctxs.values()]} def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]: ret: Dict[str, KernelRet] = {} kernel_name = "" code = "" for ctx in contexts: - if ctx.loc.split("/")[-1].split(":")[0] == "schedule.py": + if ctx.loc[0].split("/")[-1] == "schedule.py": with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, "" if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {}) diff --git a/viz/test_viz.py b/viz/test_viz.py index 1511669b49..e14cd68699 100644 --- a/viz/test_viz.py +++ b/viz/test_viz.py @@ -48,8 +48,8 @@ class TestViz(unittest.TestCase): list(lower_schedule(schedule2)) ret = load_kernels(contexts) assert len(ret) == 2 - assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc]) != 0 for y in ret) - assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc]) != 0 for y in ret) + assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc[0]]) != 0 for y in ret) + assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc[0]]) != 0 for y in ret) def test_gemm_diff(self): x = Tensor.empty(64, 64).realize() @@ -100,7 +100,7 @@ class TestViz(unittest.TestCase): new_sink = graph_rewrite(sink, pm) if DEBUG >= 4: print_diff(sink, new_sink, unified=0) self.assert_valid_ctx(contexts) - assert all(ctx.loc.split("/")[-1].split(":")[0] == __file__.split("/")[-1] for ctx in contexts) + assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for ctx in contexts) @unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules") def test_fuzz_resnet(self):