mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
viz more info for rewrite location (#6729)
This commit is contained in:
@@ -499,7 +499,7 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
|
||||
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
|
||||
@dataclass(frozen=True)
|
||||
class TrackedRewriteContext:
|
||||
loc: str # location that called graph_rewrite
|
||||
loc: Tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
kernel_name: Optional[str] = None # the name of the kernel being rewritten
|
||||
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
|
||||
@@ -566,7 +566,7 @@ class RewriteContext:
|
||||
return found
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink, _CURRENT_KERNEL.get()))
|
||||
contexts.append(TrackedRewriteContext(((f:=sys._getframe(1)).f_code.co_filename, f.f_lineno), sink, _CURRENT_KERNEL.get()))
|
||||
return RewriteContext(pm, ctx).rewrite(sink)
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
@@ -246,7 +246,7 @@
|
||||
const p = Object.assign(document.createElement("p"), {id: `kernel-${k.name}`, innerText: k.name})
|
||||
kernelUl.appendChild(p)
|
||||
k.ctxs.forEach((u, j) => {
|
||||
const rwUl = Object.assign(document.createElement("ul"), { innerText: u, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" })
|
||||
const rwUl = Object.assign(document.createElement("ul"), { innerText: `${u.filename} - ${u.match_count}`, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" })
|
||||
rwUl.style.display = i === currentKernel && expandKernel ? "block" : "none";
|
||||
rwUl.onclick = (e) => {
|
||||
e.stopPropagation();
|
||||
@@ -284,7 +284,10 @@
|
||||
// ***** RHS metadata
|
||||
const metadata = document.querySelector(".container.metadata");
|
||||
metadata.innerHTML = "";
|
||||
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc }));
|
||||
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc.filename }));
|
||||
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${DOMPurify.sanitize(ret[0].loc.code)}</code>`, className: "wrap code-block language-python" });
|
||||
metadata.appendChild(pre);
|
||||
hljs.highlightElement(pre);
|
||||
// ** resizer
|
||||
metadata.appendChild(Object.assign(document.createElement("div"), { id: "resize-handle" }));
|
||||
const resizeHandle = document.getElementById("resize-handle");
|
||||
@@ -361,7 +364,7 @@
|
||||
});
|
||||
})
|
||||
} else {
|
||||
metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${ret[0].loc}.` }));
|
||||
metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${ret[0].loc.filename}.` }));
|
||||
}
|
||||
}
|
||||
document.addEventListener("keydown", async function(event) {
|
||||
|
||||
32
viz/serve.py
32
viz/serve.py
@@ -1,13 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List, Tuple
|
||||
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, re
|
||||
from dataclasses import dataclass, asdict
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad import Device
|
||||
from tinygrad.helpers import Context, getenv, to_function_name
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
|
||||
from tinygrad.engine.graph import uops_colors, word_wrap
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.schedule import full_ast_rewrite
|
||||
@@ -17,9 +17,23 @@ from tinygrad.engine.schedule import full_ast_rewrite
|
||||
# NOTE: UPats in ops.py are spec
|
||||
def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RewriteLocation:
|
||||
filename: str
|
||||
code: str
|
||||
matcher_name: Optional[str]
|
||||
match_count: int
|
||||
@staticmethod
|
||||
def from_ctx(ctx:TrackedRewriteContext) -> RewriteLocation:
|
||||
fp, lineno = ctx.loc
|
||||
p = r"graph_rewrite\([^,]+,\s*([^>]+)\)"
|
||||
match = re.search(p, code:=lines(fp)[lineno-1].strip())
|
||||
return RewriteLocation(f"{fp.split('/')[-1]}:{lineno}", code, match.group(1) if match is not None else None, len(graph_rewrites(ctx)))
|
||||
def to_json(self): return asdict(self)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UOpRet:
|
||||
loc: str
|
||||
loc: RewriteLocation
|
||||
graphs: List[UOp] # snapshot of the entire AST after each rewrite
|
||||
diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite
|
||||
extra: List[List[str]] # these become code blocks in the UI
|
||||
@@ -42,8 +56,8 @@ class UOpRet:
|
||||
diffs.append((pattern.printable(), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
|
||||
uops.append(new_sink)
|
||||
extra.append([str(new_sink)])
|
||||
return UOpRet(ctx.loc, uops, diffs, extra, additions)
|
||||
def to_json(self) -> Dict: return {**asdict(self), "graphs": list(map(uop_to_json, self.graphs))}
|
||||
return UOpRet(RewriteLocation.from_ctx(ctx), uops, diffs, extra, additions)
|
||||
def to_json(self) -> Dict: return {**asdict(self), "loc":self.loc.to_json(), "graphs": list(map(uop_to_json, self.graphs))}
|
||||
|
||||
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
assert isinstance(x, UOp)
|
||||
@@ -71,16 +85,16 @@ def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
|
||||
class KernelRet:
|
||||
name: str
|
||||
code: str
|
||||
ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext]
|
||||
ctxs: Dict[Tuple[Tuple, bytes], TrackedRewriteContext]
|
||||
def to_json(self) -> Dict:
|
||||
return {"name":self.name, "code":self.code, "ctxs":[f"{x.loc} - {len(graph_rewrites(x))}" for x in self.ctxs.values()]}
|
||||
return {"name":self.name, "code":self.code, "ctxs":[RewriteLocation.from_ctx(x).to_json() for x in self.ctxs.values()]}
|
||||
|
||||
def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
|
||||
ret: Dict[str, KernelRet] = {}
|
||||
kernel_name = ""
|
||||
code = ""
|
||||
for ctx in contexts:
|
||||
if ctx.loc.split("/")[-1].split(":")[0] == "schedule.py":
|
||||
if ctx.loc[0].split("/")[-1] == "schedule.py":
|
||||
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
|
||||
elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
|
||||
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
|
||||
|
||||
@@ -48,8 +48,8 @@ class TestViz(unittest.TestCase):
|
||||
list(lower_schedule(schedule2))
|
||||
ret = load_kernels(contexts)
|
||||
assert len(ret) == 2
|
||||
assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc]) != 0 for y in ret)
|
||||
assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc]) != 0 for y in ret)
|
||||
assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc[0]]) != 0 for y in ret)
|
||||
assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc[0]]) != 0 for y in ret)
|
||||
|
||||
def test_gemm_diff(self):
|
||||
x = Tensor.empty(64, 64).realize()
|
||||
@@ -100,7 +100,7 @@ class TestViz(unittest.TestCase):
|
||||
new_sink = graph_rewrite(sink, pm)
|
||||
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
|
||||
self.assert_valid_ctx(contexts)
|
||||
assert all(ctx.loc.split("/")[-1].split(":")[0] == __file__.split("/")[-1] for ctx in contexts)
|
||||
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for ctx in contexts)
|
||||
|
||||
@unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules")
|
||||
def test_fuzz_resnet(self):
|
||||
|
||||
Reference in New Issue
Block a user