mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-08 13:45:50 -05:00
viz fold const nodes and UOp/float4 syntax highlight (#6695)
* fold const nodes * show rewrite count * hotfix: cpp * more syntax highlight * custom language definitions * only cpp * small fixups for UPat * extend python * cleanups * rewrites helper * better message
This commit is contained in:
@@ -12,7 +12,8 @@
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/c.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/dompurify/1.0.3/purify.min.js"></script>
|
||||
<link rel="stylesheet" href="https://unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css" />
|
||||
<style>
|
||||
* {
|
||||
@@ -67,11 +68,6 @@
|
||||
fill: #4a4b57;
|
||||
stroke-width: 1.4px;
|
||||
}
|
||||
#arrowhead {
|
||||
stroke: blue;
|
||||
fill: blue;
|
||||
stroke-width: 1.5px;
|
||||
}
|
||||
.graph {
|
||||
width: 70%;
|
||||
position: relative;
|
||||
@@ -120,7 +116,8 @@
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
width: 20px;
|
||||
cursor: w-resize;
|
||||
height: 100%;
|
||||
cursor: col-resize;
|
||||
background-color: transparent;
|
||||
}
|
||||
.rewrite-list {
|
||||
@@ -178,6 +175,23 @@
|
||||
<div class="container metadata"></div>
|
||||
</div>
|
||||
<script>
|
||||
// extra definitions for UOps
|
||||
hljs.registerLanguage("python", (hljs) => ({
|
||||
...hljs.getLanguage("python"),
|
||||
case_insensitive: false,
|
||||
contains: [
|
||||
{ begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*(\\.[a-zA-Z_][a-zA-Z0-9_-]*)*' + '(?=[.\\s\\n[:,(])', className: "type" },
|
||||
{ begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-].vec*' + '(?=[.\\s\\n[:,(])', className: "type" },
|
||||
{ begin: '[a-zA-Z_][a-zA-Z0-9_-]*\\.[a-zA-Z_][a-zA-Z0-9_-]*' + '(?=[.\\s\\n[:,()])', className: "operator" },
|
||||
{ begin: '[A-Z][a-zA-Z0-9_]*(?=\\()', className: "section", ignoreEnd: true },
|
||||
...hljs.getLanguage("python").contains,
|
||||
]
|
||||
}));
|
||||
// extra definitions for float4
|
||||
hljs.registerLanguage("cpp", (hljs) => ({
|
||||
...hljs.getLanguage('cpp'),
|
||||
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
|
||||
}));
|
||||
function renderGraph(graph, additions) {
|
||||
const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5); rx: 8; ry: 8;" : "display: none;"});
|
||||
@@ -220,7 +234,7 @@
|
||||
var expandKernel = false;
|
||||
async function main() {
|
||||
checkStatus();
|
||||
// ** kernel list
|
||||
// ***** LHS kernels list
|
||||
if (kernels == null) {
|
||||
kernels = await (await fetch("/kernels")).json();
|
||||
currentKernel = 0;
|
||||
@@ -257,7 +271,7 @@
|
||||
}
|
||||
kernelList.appendChild(kernelUl);
|
||||
});
|
||||
// ** uop graph
|
||||
// ***** UOp graph
|
||||
cacheKey = `${currentKernel}-${currentUOp}`;
|
||||
if (cacheKey in cache) {
|
||||
ret = cache[cacheKey];
|
||||
@@ -267,18 +281,20 @@
|
||||
cache[cacheKey] = ret;
|
||||
}
|
||||
renderGraph(ret[0].graphs[currentRewrite], ret[0].additions[currentRewrite]);
|
||||
// ***** RHS metadata
|
||||
const metadata = document.querySelector(".container.metadata");
|
||||
metadata.innerHTML = "";
|
||||
metadata.appendChild(Object.assign(document.createElement("div"), { id: "resize-handle" }));
|
||||
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc }));
|
||||
// ** resizer
|
||||
metadata.appendChild(Object.assign(document.createElement("div"), { id: "resize-handle" }));
|
||||
const resizeHandle = document.getElementById("resize-handle");
|
||||
|
||||
let startX;
|
||||
let containerWidth;
|
||||
let metadataWidth;
|
||||
resizeHandle.addEventListener("mousedown", (e) => {
|
||||
e.preventDefault();
|
||||
metadata.style.userSelect = "none";
|
||||
document.documentElement.style.cursor = "col-resize";
|
||||
startX = e.clientX;
|
||||
containerWidth = document.querySelector(".main-container").getBoundingClientRect().width;
|
||||
metadataWidth = metadata.getBoundingClientRect().width;
|
||||
@@ -296,17 +312,19 @@
|
||||
function stopResize(e) {
|
||||
document.documentElement.removeEventListener("mousemove", resize, false);
|
||||
document.documentElement.removeEventListener("mouseup", stopResize, false);
|
||||
document.documentElement.style.cursor = "initial";
|
||||
metadata.style.userSelect = "initial";
|
||||
}
|
||||
|
||||
// ** code blocks
|
||||
ret[0].extra[currentRewrite].forEach((e, i) => {
|
||||
if (e.length == 0) return;
|
||||
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${e}</code>`, className: "code-block language-python" });
|
||||
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${DOMPurify.sanitize(e)}</code>`, className: "code-block language-python" });
|
||||
hljs.highlightElement(pre);
|
||||
metadata.appendChild(pre);
|
||||
})
|
||||
if (kernels[currentKernel].code !== "") {
|
||||
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${kernels[currentKernel].code}</code>`, className: "code-block" });
|
||||
const code = kernels[currentKernel].code.replaceAll("<", "<").replaceAll(">", ">");
|
||||
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${DOMPurify.sanitize(code)}</code>`, className: "code-block language-cpp" });
|
||||
hljs.highlightElement(pre);
|
||||
metadata.appendChild(pre);
|
||||
}
|
||||
@@ -326,7 +344,7 @@
|
||||
link = Object.assign(document.createElement("a"), { textContent: parts[parts.length-1]+"\n\n", href: "vscode://file"+parts.join("/"), style: "font-family: monospace; margin: 4px 0;" })
|
||||
div.appendChild(link);
|
||||
const pre = Object.assign(document.createElement("pre"), { className: "code-block wrap" });
|
||||
pre.appendChild(Object.assign(document.createElement("code"), { textContent: pattern, className: "language-python" }));
|
||||
pre.appendChild(Object.assign(document.createElement("code"), { textContent: DOMPurify.sanitize(pattern), className: "language-python" }));
|
||||
div.appendChild(pre);
|
||||
hljs.highlightElement(pre);
|
||||
metadata.appendChild(div);
|
||||
@@ -343,7 +361,7 @@
|
||||
});
|
||||
})
|
||||
} else {
|
||||
metadata.appendChild(Object.assign(document.createElement("p"), { textContent: "No rewrites" }));
|
||||
metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${ret[0].loc}.` }));
|
||||
}
|
||||
}
|
||||
document.addEventListener("keydown", async function(event) {
|
||||
|
||||
19
viz/serve.py
19
viz/serve.py
@@ -7,13 +7,16 @@ from urllib.parse import parse_qs, urlparse
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad import Device
|
||||
from tinygrad.helpers import Context, getenv, to_function_name
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps
|
||||
from tinygrad.engine.graph import uops_colors, word_wrap
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.schedule import full_ast_rewrite
|
||||
|
||||
# **** /graph - detailed UOp + rewrites
|
||||
|
||||
# NOTE: UPats in ops.py are spec
|
||||
def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UOpRet:
|
||||
loc: str
|
||||
@@ -28,16 +31,15 @@ class UOpRet:
|
||||
extra: List[List[str]] = [[str(ctx.sink)]]
|
||||
additions: List[List[int]] = [[]]
|
||||
seen_replaces: Dict[bytes, UOp] = {}
|
||||
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
|
||||
if pattern.location[0].split("/")[-1] == "ops.py": continue
|
||||
for i, (first, rewritten, pattern) in enumerate(graph_rewrites(ctx)):
|
||||
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
|
||||
seen_replaces[first.key] = rewritten
|
||||
new_sink = replace_uop(uops[-1], {**seen_replaces})
|
||||
# sanity check
|
||||
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
|
||||
# update ret data
|
||||
additions.append([id(x) for x in rewritten.sparents])
|
||||
diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
|
||||
additions.append([id(x) for x in rewritten.sparents if x.op is not UOps.CONST])
|
||||
diffs.append((pattern.printable(), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
|
||||
uops.append(new_sink)
|
||||
extra.append([str(new_sink)])
|
||||
return UOpRet(ctx.loc, uops, diffs, extra, additions)
|
||||
@@ -47,11 +49,14 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
assert isinstance(x, UOp)
|
||||
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
|
||||
for u in x.sparents:
|
||||
if u.op is UOps.CONST and u is not x: continue
|
||||
label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
|
||||
for idx,x in enumerate(u.src):
|
||||
if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}"
|
||||
if getenv("WITH_SHAPE"):
|
||||
with contextlib.suppress(Exception): # if the UOp is indexed already it's fine
|
||||
if u.st is not None: label += f"\n{u.st.shape}"
|
||||
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src], str(u.arg), uops_colors.get(u.op, "#ffffff"))
|
||||
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
|
||||
return graph
|
||||
|
||||
def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
|
||||
@@ -68,7 +73,7 @@ class KernelRet:
|
||||
code: str
|
||||
ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext]
|
||||
def to_json(self) -> Dict:
|
||||
return {"name":self.name, "code":self.code, "ctxs":[x.loc for x in self.ctxs.values()]}
|
||||
return {"name":self.name, "code":self.code, "ctxs":[f"{x.loc} - {len(graph_rewrites(x))}" for x in self.ctxs.values()]}
|
||||
|
||||
def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
|
||||
ret: Dict[str, KernelRet] = {}
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv
|
||||
from tinygrad.codegen.uopgraph import constant_folder, devectorize, float4_folding
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from viz.serve import UOpRet, load_kernels
|
||||
from viz.serve import UOpRet, load_kernels, uop_to_json
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def tearDown(self) -> None:
|
||||
@@ -137,5 +137,19 @@ class TestViz(unittest.TestCase):
|
||||
schedule_ctxs = [x for x in kernels[1].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
|
||||
self.assertEqual(len(schedule_ctxs), 0)
|
||||
|
||||
def test_fold_const_nodes(self):
|
||||
a = Tensor.empty(4, 4)+2
|
||||
contexts.clear()
|
||||
sink = a.schedule()[-1].ast
|
||||
ret = uop_to_json(sink)
|
||||
for v in ret.values(): print(v)
|
||||
assert not any(v[0].startswith("CONST") for v in ret.values())
|
||||
assert len([x for x in ret.values() if "CONST" in x[0]]) == 1
|
||||
|
||||
def test_no_fold_single_const(self):
|
||||
node = UOp(UOps.CONST, dtypes.float, (), 1.0)
|
||||
ret = uop_to_json(node)
|
||||
assert len(ret) == 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user