viz more info for rewrite location (#6729)

This commit is contained in:
qazal
2024-09-25 14:49:40 +08:00
committed by GitHub
parent 39f78619ff
commit 6c69fec1ef
4 changed files with 34 additions and 17 deletions

View File

@@ -499,7 +499,7 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
@dataclass(frozen=True)
class TrackedRewriteContext:
loc: str # location that called graph_rewrite
loc: Tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
kernel_name: Optional[str] = None # the name of the kernel being rewritten
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
@@ -566,7 +566,7 @@ class RewriteContext:
return found
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
if TRACK_MATCH_STATS >= 2:
contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink, _CURRENT_KERNEL.get()))
contexts.append(TrackedRewriteContext(((f:=sys._getframe(1)).f_code.co_filename, f.f_lineno), sink, _CURRENT_KERNEL.get()))
return RewriteContext(pm, ctx).rewrite(sink)
# ***** uop type spec *****

View File

@@ -246,7 +246,7 @@
const p = Object.assign(document.createElement("p"), {id: `kernel-${k.name}`, innerText: k.name})
kernelUl.appendChild(p)
k.ctxs.forEach((u, j) => {
const rwUl = Object.assign(document.createElement("ul"), { innerText: u, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" })
const rwUl = Object.assign(document.createElement("ul"), { innerText: `${u.filename} - ${u.match_count}`, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" })
rwUl.style.display = i === currentKernel && expandKernel ? "block" : "none";
rwUl.onclick = (e) => {
e.stopPropagation();
@@ -284,7 +284,10 @@
// ***** RHS metadata
const metadata = document.querySelector(".container.metadata");
metadata.innerHTML = "";
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc }));
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc.filename }));
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${DOMPurify.sanitize(ret[0].loc.code)}</code>`, className: "wrap code-block language-python" });
metadata.appendChild(pre);
hljs.highlightElement(pre);
// ** resizer
metadata.appendChild(Object.assign(document.createElement("div"), { id: "resize-handle" }));
const resizeHandle = document.getElementById("resize-handle");
@@ -361,7 +364,7 @@
});
})
} else {
metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${ret[0].loc}.` }));
metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${ret[0].loc.filename}.` }));
}
}
document.addEventListener("keydown", async function(event) {

View File

@@ -1,13 +1,13 @@
#!/usr/bin/env python3
from __future__ import annotations
from typing import Dict, List, Tuple
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib
from typing import Dict, List, Optional, Tuple
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, re
from dataclasses import dataclass, asdict
from urllib.parse import parse_qs, urlparse
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad import Device
from tinygrad.helpers import Context, getenv, to_function_name
from tinygrad.ops import TrackedRewriteContext, UOp, UOps
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.engine.graph import uops_colors, word_wrap
from tinygrad.engine.realize import get_runner
from tinygrad.engine.schedule import full_ast_rewrite
@@ -17,9 +17,23 @@ from tinygrad.engine.schedule import full_ast_rewrite
# NOTE: UPats in ops.py are spec
def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
@dataclass(frozen=True)
class RewriteLocation:
filename: str
code: str
matcher_name: Optional[str]
match_count: int
@staticmethod
def from_ctx(ctx:TrackedRewriteContext) -> RewriteLocation:
fp, lineno = ctx.loc
p = r"graph_rewrite\([^,]+,\s*([^>]+)\)"
match = re.search(p, code:=lines(fp)[lineno-1].strip())
return RewriteLocation(f"{fp.split('/')[-1]}:{lineno}", code, match.group(1) if match is not None else None, len(graph_rewrites(ctx)))
def to_json(self): return asdict(self)
@dataclass(frozen=True)
class UOpRet:
loc: str
loc: RewriteLocation
graphs: List[UOp] # snapshot of the entire AST after each rewrite
diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite
extra: List[List[str]] # these become code blocks in the UI
@@ -42,8 +56,8 @@ class UOpRet:
diffs.append((pattern.printable(), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
uops.append(new_sink)
extra.append([str(new_sink)])
return UOpRet(ctx.loc, uops, diffs, extra, additions)
def to_json(self) -> Dict: return {**asdict(self), "graphs": list(map(uop_to_json, self.graphs))}
return UOpRet(RewriteLocation.from_ctx(ctx), uops, diffs, extra, additions)
def to_json(self) -> Dict: return {**asdict(self), "loc":self.loc.to_json(), "graphs": list(map(uop_to_json, self.graphs))}
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)
@@ -71,16 +85,16 @@ def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
class KernelRet:
name: str
code: str
ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext]
ctxs: Dict[Tuple[Tuple, bytes], TrackedRewriteContext]
def to_json(self) -> Dict:
return {"name":self.name, "code":self.code, "ctxs":[f"{x.loc} - {len(graph_rewrites(x))}" for x in self.ctxs.values()]}
return {"name":self.name, "code":self.code, "ctxs":[RewriteLocation.from_ctx(x).to_json() for x in self.ctxs.values()]}
def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
ret: Dict[str, KernelRet] = {}
kernel_name = ""
code = ""
for ctx in contexts:
if ctx.loc.split("/")[-1].split(":")[0] == "schedule.py":
if ctx.loc[0].split("/")[-1] == "schedule.py":
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})

View File

@@ -48,8 +48,8 @@ class TestViz(unittest.TestCase):
list(lower_schedule(schedule2))
ret = load_kernels(contexts)
assert len(ret) == 2
assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc]) != 0 for y in ret)
assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc]) != 0 for y in ret)
assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc[0]]) != 0 for y in ret)
assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc[0]]) != 0 for y in ret)
def test_gemm_diff(self):
x = Tensor.empty(64, 64).realize()
@@ -100,7 +100,7 @@ class TestViz(unittest.TestCase):
new_sink = graph_rewrite(sink, pm)
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
self.assert_valid_ctx(contexts)
assert all(ctx.loc.split("/")[-1].split(":")[0] == __file__.split("/")[-1] for ctx in contexts)
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for ctx in contexts)
@unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules")
def test_fuzz_resnet(self):