viz fold const nodes and UOp/float4 syntax highlight (#6695)

* fold const nodes

* show rewrite count

* hotfix: cpp

* more syntax highlight

* custom language definitions

* only cpp

* small fixups for UPat

* extend python

* cleanups

* rewrites helper

* better message
This commit is contained in:
qazal
2024-09-24 14:36:59 +08:00
committed by GitHub
parent 4bb1694f49
commit 048483ee0b
3 changed files with 61 additions and 24 deletions

View File

@@ -12,7 +12,8 @@
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/c.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/dompurify/1.0.3/purify.min.js"></script>
<link rel="stylesheet" href="https://unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css" />
<style>
* {
@@ -67,11 +68,6 @@
fill: #4a4b57;
stroke-width: 1.4px;
}
#arrowhead {
stroke: blue;
fill: blue;
stroke-width: 1.5px;
}
.graph {
width: 70%;
position: relative;
@@ -120,7 +116,8 @@
top: 0;
bottom: 0;
width: 20px;
cursor: w-resize;
height: 100%;
cursor: col-resize;
background-color: transparent;
}
.rewrite-list {
@@ -178,6 +175,23 @@
<div class="container metadata"></div>
</div>
<script>
// extra definitions for UOps
hljs.registerLanguage("python", (hljs) => ({
...hljs.getLanguage("python"),
case_insensitive: false,
contains: [
{ begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*(\\.[a-zA-Z_][a-zA-Z0-9_-]*)*' + '(?=[.\\s\\n[:,(])', className: "type" },
{ begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-].vec*' + '(?=[.\\s\\n[:,(])', className: "type" },
{ begin: '[a-zA-Z_][a-zA-Z0-9_-]*\\.[a-zA-Z_][a-zA-Z0-9_-]*' + '(?=[.\\s\\n[:,()])', className: "operator" },
{ begin: '[A-Z][a-zA-Z0-9_]*(?=\\()', className: "section", ignoreEnd: true },
...hljs.getLanguage("python").contains,
]
}));
// extra definitions for float4
hljs.registerLanguage("cpp", (hljs) => ({
...hljs.getLanguage('cpp'),
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
}));
function renderGraph(graph, additions) {
const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5); rx: 8; ry: 8;" : "display: none;"});
@@ -220,7 +234,7 @@
var expandKernel = false;
async function main() {
checkStatus();
// ** kernel list
// ***** LHS kernels list
if (kernels == null) {
kernels = await (await fetch("/kernels")).json();
currentKernel = 0;
@@ -257,7 +271,7 @@
}
kernelList.appendChild(kernelUl);
});
// ** uop graph
// ***** UOp graph
cacheKey = `${currentKernel}-${currentUOp}`;
if (cacheKey in cache) {
ret = cache[cacheKey];
@@ -267,18 +281,20 @@
cache[cacheKey] = ret;
}
renderGraph(ret[0].graphs[currentRewrite], ret[0].additions[currentRewrite]);
// ***** RHS metadata
const metadata = document.querySelector(".container.metadata");
metadata.innerHTML = "";
metadata.appendChild(Object.assign(document.createElement("div"), { id: "resize-handle" }));
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc }));
// ** resizer
metadata.appendChild(Object.assign(document.createElement("div"), { id: "resize-handle" }));
const resizeHandle = document.getElementById("resize-handle");
let startX;
let containerWidth;
let metadataWidth;
resizeHandle.addEventListener("mousedown", (e) => {
e.preventDefault();
metadata.style.userSelect = "none";
document.documentElement.style.cursor = "col-resize";
startX = e.clientX;
containerWidth = document.querySelector(".main-container").getBoundingClientRect().width;
metadataWidth = metadata.getBoundingClientRect().width;
@@ -296,17 +312,19 @@
function stopResize(e) {
document.documentElement.removeEventListener("mousemove", resize, false);
document.documentElement.removeEventListener("mouseup", stopResize, false);
document.documentElement.style.cursor = "initial";
metadata.style.userSelect = "initial";
}
// ** code blocks
ret[0].extra[currentRewrite].forEach((e, i) => {
if (e.length == 0) return;
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${e}</code>`, className: "code-block language-python" });
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${DOMPurify.sanitize(e)}</code>`, className: "code-block language-python" });
hljs.highlightElement(pre);
metadata.appendChild(pre);
})
if (kernels[currentKernel].code !== "") {
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${kernels[currentKernel].code}</code>`, className: "code-block" });
const code = kernels[currentKernel].code.replaceAll("<", "&lt;").replaceAll(">", "&gt;");
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${DOMPurify.sanitize(code)}</code>`, className: "code-block language-cpp" });
hljs.highlightElement(pre);
metadata.appendChild(pre);
}
@@ -326,7 +344,7 @@
link = Object.assign(document.createElement("a"), { textContent: parts[parts.length-1]+"\n\n", href: "vscode://file"+parts.join("/"), style: "font-family: monospace; margin: 4px 0;" })
div.appendChild(link);
const pre = Object.assign(document.createElement("pre"), { className: "code-block wrap" });
pre.appendChild(Object.assign(document.createElement("code"), { textContent: pattern, className: "language-python" }));
pre.appendChild(Object.assign(document.createElement("code"), { textContent: DOMPurify.sanitize(pattern), className: "language-python" }));
div.appendChild(pre);
hljs.highlightElement(pre);
metadata.appendChild(div);
@@ -343,7 +361,7 @@
});
})
} else {
metadata.appendChild(Object.assign(document.createElement("p"), { textContent: "No rewrites" }));
metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${ret[0].loc}.` }));
}
}
document.addEventListener("keydown", async function(event) {

View File

@@ -7,13 +7,16 @@ from urllib.parse import parse_qs, urlparse
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad import Device
from tinygrad.helpers import Context, getenv, to_function_name
from tinygrad.ops import TrackedRewriteContext, UOp
from tinygrad.ops import TrackedRewriteContext, UOp, UOps
from tinygrad.engine.graph import uops_colors, word_wrap
from tinygrad.engine.realize import get_runner
from tinygrad.engine.schedule import full_ast_rewrite
# **** /graph - detailed UOp + rewrites
# NOTE: UPats in ops.py are spec
def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
@dataclass(frozen=True)
class UOpRet:
loc: str
@@ -28,16 +31,15 @@ class UOpRet:
extra: List[List[str]] = [[str(ctx.sink)]]
additions: List[List[int]] = [[]]
seen_replaces: Dict[bytes, UOp] = {}
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
if pattern.location[0].split("/")[-1] == "ops.py": continue
for i, (first, rewritten, pattern) in enumerate(graph_rewrites(ctx)):
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
seen_replaces[first.key] = rewritten
new_sink = replace_uop(uops[-1], {**seen_replaces})
# sanity check
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
# update ret data
additions.append([id(x) for x in rewritten.sparents])
diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
additions.append([id(x) for x in rewritten.sparents if x.op is not UOps.CONST])
diffs.append((pattern.printable(), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
uops.append(new_sink)
extra.append([str(new_sink)])
return UOpRet(ctx.loc, uops, diffs, extra, additions)
@@ -47,11 +49,14 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
for u in x.sparents:
if u.op is UOps.CONST and u is not x: continue
label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
for idx,x in enumerate(u.src):
if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}"
if getenv("WITH_SHAPE"):
with contextlib.suppress(Exception): # if the UOp is indexed already it's fine
if u.st is not None: label += f"\n{u.st.shape}"
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src], str(u.arg), uops_colors.get(u.op, "#ffffff"))
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
return graph
def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
@@ -68,7 +73,7 @@ class KernelRet:
code: str
ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext]
def to_json(self) -> Dict:
return {"name":self.name, "code":self.code, "ctxs":[x.loc for x in self.ctxs.values()]}
return {"name":self.name, "code":self.code, "ctxs":[f"{x.loc} - {len(graph_rewrites(x))}" for x in self.ctxs.values()]}
def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
ret: Dict[str, KernelRet] = {}

View File

@@ -9,7 +9,7 @@ from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv
from tinygrad.codegen.uopgraph import constant_folder, devectorize, float4_folding
from test.external.process_replay.helpers import print_diff
from viz.serve import UOpRet, load_kernels
from viz.serve import UOpRet, load_kernels, uop_to_json
class TestViz(unittest.TestCase):
def tearDown(self) -> None:
@@ -137,5 +137,19 @@ class TestViz(unittest.TestCase):
schedule_ctxs = [x for x in kernels[1].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
self.assertEqual(len(schedule_ctxs), 0)
def test_fold_const_nodes(self):
a = Tensor.empty(4, 4)+2
contexts.clear()
sink = a.schedule()[-1].ast
ret = uop_to_json(sink)
for v in ret.values(): print(v)
assert not any(v[0].startswith("CONST") for v in ret.values())
assert len([x for x in ret.values() if "CONST" in x[0]]) == 1
def test_no_fold_single_const(self):
node = UOp(UOps.CONST, dtypes.float, (), 1.0)
ret = uop_to_json(node)
assert len(ret) == 1
if __name__ == "__main__":
unittest.main()