Files
tinygrad/viz/index.html
qazal fb3fe6f39b better VIZ (#6781)
* ui changes

* make kernels global

* dont save buffers when running VIZ=1

* remove flex in layout

* use os.execv

* del server thread

* server close

* cleanup

* logs cleanup

* rm getenv

* cleanups

* remove global
2024-09-27 18:38:31 +08:00

425 lines
14 KiB
HTML

<!DOCTYPE html>
<html>
<head>
<title>tinygrad viz</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" href="favicon.svg" type="image/svg+xml">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Noto+Sans+HK:wght@100..900&display=swap" rel="stylesheet">
<script src="https://d3js.org/d3.v5.min.js" charset="utf-8"></script>
<script src="https://dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/dompurify/1.0.3/purify.min.js"></script>
<link rel="stylesheet" href="https://unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css" />
<style>
/* Global reset: border-box sizing everywhere and block margins set back to `initial`. */
* {
box-sizing: border-box;
margin-block-start: initial;
margin-block-end: initial;
}
/* Page shell: dark theme filling the viewport. The page itself never scrolls
   (overflow: hidden); .kernel-list and .metadata scroll on their own. */
/* NOTE(review): the stylesheet linked in <head> loads the "Noto Sans HK" family, but
   font-family below asks for "Noto Sans" - confirm which family is intended. */
html, body {
background-color: #08090e;
color: #f0f0f5;
margin: 0;
padding: 0;
width: 100%;
height: 100%;
font-family: "Noto Sans", sans-serif;
font-optical-sizing: auto;
font-weight: 400;
font-style: normal;
font-variation-settings: "wdth" 100;
font-size: 14px;
overflow: hidden;
}
a {
color: #4a90e2;
}
/* <ul> elements are used by the script as clickable list entries; they are dimmed by default. */
ul {
padding: 0;
color: #7c7d85;
white-space: nowrap;
cursor: pointer;
}
/* The script sets class "active" on the currently selected entry. */
ul.active {
color: #f0f0f5;
}
/* The graph SVG fills its panel; nothing inside it is text-selectable. */
svg {
width: 100%;
height: 100%;
}
svg * {
cursor: default;
user-select: none;
}
/* Graph node / label / edge styling (presumably the class names emitted by dagre-d3 - verify). */
.node rect {
stroke: #4a4b57;
stroke-width: 1.5px;
}
.label text {
color: #08090e;
font-weight: 350;
}
.edgePath path {
stroke: #4a4b57;
fill: #4a4b57;
stroke-width: 1.4px;
}
/* Panel widths: .kernel-list 10% + .graph 70% + .metadata 20%. The resize code in the
   script rewrites the .graph and .metadata widths inline. */
.graph {
width: 70%;
position: relative;
}
.main-container {
display: flex;
padding: 12px;
width: 100%;
height: 100%;
}
/* Shared panel chrome; position: relative makes each panel the containing block for
   absolutely positioned children such as #resize-handle. */
.container {
background-color: #0f1018;
height: 100%;
border-radius: 8px;
padding: 8px;
position: relative;
}
/* "> * + *" rules add spacing between adjacent siblings only (no margin on the first child). */
.container > * + * {
margin-top: 12px;
}
.rewrite-container > * + * {
margin-top: 12px;
}
.main-container > * + * {
margin-left: 8px;
}
.kernel-list {
width: 10%;
overflow-y: auto;
}
.kernel-list > ul > ul {
padding-left: 4px;
}
.kernel-list > ul > * + * {
margin-top: 4px;
}
/* padding-left leaves a 20px strip that #resize-handle sits on top of. */
.metadata {
width: 20%;
padding-left: 20px;
overflow-y: auto;
}
/* Invisible full-height drag strip on the left edge of the panel it is appended to. */
#resize-handle {
position: absolute;
left: 0;
top: 0;
bottom: 0;
width: 20px;
height: 100%;
cursor: col-resize;
background-color: transparent;
}
/* Row of rewrite-step entries; wraps onto multiple lines when it overflows. */
.rewrite-list {
display: flex;
flex-wrap: wrap;
}
.rewrite-list > * + * {
margin-left: 4px;
}
/* Preserve whitespace in code but allow long lines to wrap. */
.wrap {
word-wrap: break-word;
white-space: pre-wrap;
}
/* Code blocks are capped at 30% of the panel height and scroll beyond that. */
.code-block {
max-height: 30%;
overflow-y: auto;
border-radius: 8px;
padding: 8px;
}
/* Rotating ring loading indicator. */
/* NOTE(review): no element with class "spinner" is created in the markup or script of this file - confirm it is still used. */
.spinner {
position: absolute;
left: 50%;
top: 50%;
width: 32px;
height: 32px;
border: 4px solid #B4B4B4;
border-bottom-color: transparent;
border-radius: 50%;
display: inline-block;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
/* Connection-error text written by checkStatus(). */
.status {
color: #EC5D5E;
font-weight: bold;
}
</style>
</head>
<body>
<!-- Three-panel layout. The inline script below fills .kernel-list and .metadata at runtime. -->
<div class="main-container">
<!-- Left panel: kernel list, rebuilt by main() -->
<div class="container kernel-list"></div>
<!-- Centre panel: connection status text plus the SVG whose <g> receives the rendered graph -->
<div class="container graph">
<div class="status"></div>
<svg>
<g id="render"></g>
</svg>
</div>
<!-- Right panel: source location, code blocks and the rewrite list, rebuilt by main() -->
<div class="container metadata"></div>
</div>
<script>
// Extra highlighting rules for UOp dumps, layered in front of the stock Python grammar
// (rules listed first take precedence over the stock `contains` entries).
hljs.registerLanguage("python", (hljs) => {
  const stockPython = hljs.getLanguage("python");
  return {
    ...stockPython,
    case_insensitive: false,
    contains: [
      // `dtypes.<name>` and any dotted chain hanging off it, when followed by a delimiter
      { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*(\\.[a-zA-Z_][a-zA-Z0-9_-]*)*' + '(?=[.\\s\\n[:,(])', className: "type" },
      // `dtypes.<name>.vec`. The previous pattern had no `*` on the name's character class,
      // left the `.` unescaped and wrote `vec*` (which only repeats the `c`), so it could
      // never match a real vectorised dtype.
      { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*\\.vec' + '(?=[.\\s\\n[:,(])', className: "type" },
      // `<name>.<name>` pairs followed by a delimiter
      { begin: '[a-zA-Z_][a-zA-Z0-9_-]*\\.[a-zA-Z_][a-zA-Z0-9_-]*' + '(?=[.\\s\\n[:,()])', className: "operator" },
      // Capitalised identifier immediately followed by `(`
      { begin: '[A-Z][a-zA-Z0-9_]*(?=\\()', className: "section", ignoreEnd: true },
      ...stockPython.contains,
    ]
  };
});
// Extend the stock C++ grammar so vector scalar names (float4, half8, ...) are
// highlighted as types; the extra rule is placed ahead of the built-in ones.
hljs.registerLanguage("cpp", (hljs) => {
  const stockCpp = hljs.getLanguage('cpp');
  const vectorTypeRule = { begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' };
  return {
    ...stockCpp,
    contains: [vectorTypeRule, ...stockCpp.contains],
  };
});
/**
 * Lay out one graph with dagre-d3 and draw it into the page's <svg>.
 *
 * @param {Object<string, Array>} graph - Maps a node id to a tuple. Index 0 is used as the
 *   node label, index 2 is iterated as the ids of the node's sources (one edge per entry,
 *   pointing source -> node) and index 4 is used as the CSS fill colour.
 * @param {number[]} additions - Integer node ids to place inside the "addition" cluster.
 *   When empty, the cluster node is still created but hidden.
 */
function renderGraph(graph, additions) {
  const g = new dagreD3.graphlib.Graph({ compound: true })
    .setGraph({ rankdir: "LR" })
    .setDefaultEdgeLabel(() => ({}));
  const clusterStyle = additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5); rx: 8; ry: 8;" : "display: none;";
  g.setNode("addition", { label: "", clusterLabelPos: "top", style: clusterStyle });
  // Loop bindings are declared locally; they were previously assigned without a
  // declaration, which leaked globals and throws in strict/module code.
  for (const [id, uop] of Object.entries(graph)) {
    g.setNode(id, { label: uop[0], style: `fill: ${uop[4]}; rx: 8; ry: 8;` });
    for (const src of uop[2]) {
      g.setEdge(src, id, { curve: d3.curveBasis });
    }
    // Object keys are strings while `additions` holds numbers, hence the explicit base-10 parse.
    if (additions.includes(Number.parseInt(id, 10))) {
      g.setParent(id, "addition");
    }
  }
  const svg = d3.select("svg");
  const inner = svg.select("g");
  const zoom = d3.zoom()
    .scaleExtent([0.05, 2])
    .on("zoom", () => {
      // d3 v5 exposes the active event on the d3 namespace
      inner.attr("transform", d3.event.transform);
    });
  svg.call(zoom);
  const render = new dagreD3.render();
  render(inner, g);
}
/**
 * Probe the server with a request to "/" and update the `.status` element:
 * empty text when the response is OK, an error message otherwise.
 *
 * @returns {Promise<void>} Resolves once the status text has been written; a failed
 *   request is reported in the page rather than thrown.
 */
async function checkStatus() {
  // Declared locally; it was previously assigned without a declaration (implicit global).
  let active;
  try {
    active = (await fetch("/")).ok;
  } catch {
    // fetch rejected, so no response was received at all
    active = false;
  }
  document.querySelector('.status').textContent = active ? `` : `Connection to localhost:8000 failed, is VIZ=1 running?`;
}
// ***** shared UI state, read and written by main() and the keydown handler
// most recent /graph response; main() and the keydown handler read ret[0]
var ret = {};
// /graph responses memoised under the key `${currentKernel}-${currentUOp}`
var cache = {};
// /kernels response; null until main() has fetched it
var kernels = null;
// index into kernels[currentKernel].ctxs of the selected entry
var currentUOp = 0;
// index into kernels of the selected kernel
var currentKernel = 0;
// index into ret[0].graphs of the rewrite step being shown
var currentRewrite = 0;
// whether the selected kernel's child entries are shown in the list
var expandKernel = true;
/**
 * Re-render the whole page from the current selection state.
 *
 * Reads the file-level `currentKernel`, `currentUOp`, `currentRewrite` and `expandKernel`,
 * fetches `/kernels` once and `/graph` once per (kernel, uop) pair (memoised in `cache`),
 * then rebuilds the kernel list, the graph and the metadata panel. The click handlers
 * installed here update the selection state and call `main()` again.
 *
 * All server-provided strings are inserted as text (`textContent`), never as markup, so
 * characters such as `<` in code or diffs are displayed literally instead of being parsed
 * as HTML.
 *
 * @returns {Promise<void>} Rejects if a fetch or JSON parse fails.
 */
async function main() {
  const el = (tag, props) => Object.assign(document.createElement(tag), props);
  // Build a highlighted <pre> from plain text; the language comes from `className`.
  const codeBlock = (text, className) => {
    const pre = el("pre", { textContent: text, className });
    hljs.highlightElement(pre);
    return pre;
  };

  // Deliberately not awaited: the status text updates independently of the render.
  checkStatus();
  // ***** LHS kernels list
  if (kernels == null) {
    kernels = await (await fetch("/kernels")).json();
    currentKernel = 0;
  }
  const kernelList = document.querySelector(".container.kernel-list");
  kernelList.innerHTML = "";
  kernels.forEach((k, i) => {
    const isSelected = i === currentKernel;
    const kernelUl = el("ul", { key: `kernel-${i}`, className: isSelected ? "active" : "", style: "overflow-x: auto; cursor: initial;" });
    const title = el("p", { id: `kernel-${k.name}`, innerText: k.name, style: "cursor: pointer;" });
    kernelUl.appendChild(title);
    k.ctxs.forEach((u, j) => {
      const rwUl = el("ul", {
        innerText: `${u.filename} - ${u.match_count} - ${u.matcher_name}`,
        key: `uop-rewrite-${j}`,
        className: (isSelected && j === currentUOp) ? "active" : "",
      });
      // child entries are only visible for the selected kernel while it is expanded
      rwUl.style.display = isSelected && expandKernel ? "block" : "none";
      rwUl.onclick = (e) => {
        e.stopPropagation();
        currentUOp = j;
        currentKernel = i;
        currentRewrite = 0;
        main();
      };
      kernelUl.appendChild(rwUl);
    });
    title.onclick = () => {
      // clicking the already-selected kernel toggles its expansion; any other kernel gets selected
      if (i === currentKernel) {
        expandKernel = !expandKernel;
        main();
        return;
      }
      currentKernel = i;
      currentUOp = 0;
      currentRewrite = 0;
      expandKernel = true;
      main();
    };
    kernelList.appendChild(kernelUl);
  });
  // ***** UOp graph
  const cacheKey = `${currentKernel}-${currentUOp}`;
  if (!(cacheKey in cache)) {
    cache[cacheKey] = await (await fetch(`/graph?kernel_idx=${currentKernel}&uop_idx=${currentUOp}`)).json();
  }
  ret = cache[cacheKey];
  const ctx = ret[0];
  renderGraph(ctx.graphs[currentRewrite], ctx.additions[currentRewrite]);
  // ***** RHS metadata
  const metadata = document.querySelector(".container.metadata");
  metadata.innerHTML = "";
  metadata.appendChild(el("pre", { textContent: ctx.loc.filename }));
  metadata.appendChild(codeBlock(ctx.loc.code, "wrap code-block language-python"));
  // ** resizer
  const resizeHandle = el("div", { id: "resize-handle" });
  metadata.appendChild(resizeHandle);
  let startX;
  let containerWidth;
  let metadataWidth;
  const resize = (e) => {
    const change = e.clientX - startX;
    // dragging left (negative change) widens the panel; width is expressed as % of the container
    const newWidth = ((metadataWidth-change) / containerWidth) * 100;
    if (newWidth >= 20 && newWidth <= 50) {
      metadata.style.width = `${newWidth}%`;
      // the remaining 10% is the kernel list
      document.querySelector(".graph").style.width = `${100-newWidth-10}%`;
    }
  };
  const stopResize = () => {
    document.documentElement.removeEventListener("mousemove", resize, false);
    document.documentElement.removeEventListener("mouseup", stopResize, false);
    document.documentElement.style.cursor = "initial";
    metadata.style.userSelect = "initial";
  };
  resizeHandle.addEventListener("mousedown", (e) => {
    e.preventDefault();
    metadata.style.userSelect = "none";
    document.documentElement.style.cursor = "col-resize";
    startX = e.clientX;
    containerWidth = document.querySelector(".main-container").getBoundingClientRect().width;
    metadataWidth = metadata.getBoundingClientRect().width;
    document.documentElement.addEventListener("mousemove", resize, false);
    document.documentElement.addEventListener("mouseup", stopResize, false);
  });
  // ** code blocks
  ctx.extra[currentRewrite].forEach((e) => {
    if (e.length === 0) return;
    metadata.appendChild(codeBlock(e, "code-block language-python"));
  });
  const kernelCode = kernels[currentKernel].code;
  if (kernelCode !== "") {
    metadata.appendChild(codeBlock(kernelCode, "code-block language-cpp"));
  }
  // ** rewrite list
  if (ctx.graphs.length > 1) {
    const rewriteList = el("div", { className: "rewrite-list" });
    metadata.appendChild(rewriteList);
    ctx.graphs.forEach((rw, i) => {
      const gUl = el("ul", { innerText: i, key: `rewrite-${i}` });
      rewriteList.appendChild(gUl);
      if (i === currentRewrite) {
        gUl.classList.add("active");
        if (i !== 0) {
          // diffs[i-1] describes the step that produced graphs[i]
          const [pattern, loc, diff] = ctx.diffs[i-1];
          const location = loc.join(":");
          const parts = location.split("/");
          const div = el("div", { className: "rewrite-container" });
          const link = el("a", { textContent: `${parts[parts.length-1]}\n\n`, href: `vscode://file${location}`, style: "font-family: monospace; margin: 4px 0;" });
          div.appendChild(link);
          const pre = el("pre", { className: "code-block wrap" });
          // textContent already treats the pattern as text; running it through an HTML
          // sanitizer first would turn characters like `<` into visible entities.
          pre.appendChild(el("code", { textContent: pattern, className: "language-python" }));
          div.appendChild(pre);
          hljs.highlightElement(pre);
          metadata.appendChild(div);
          // One coloured <span> per diff line, separated by <br>; built with DOM nodes so
          // the line text is never interpreted as markup.
          const diffCode = document.createElement("code");
          diff.forEach((line, n) => {
            if (n > 0) diffCode.appendChild(document.createElement("br"));
            const span = el("span", { textContent: line });
            span.style.color = line.startsWith("+") ? "#3aa56d" : line.startsWith("-") ? "#d14b4b" : "#f0f0f5";
            diffCode.appendChild(span);
          });
          const diffPre = el("pre", { className: "wrap" });
          diffPre.appendChild(diffCode);
          metadata.appendChild(diffPre);
        }
      }
      gUl.addEventListener("click", () => {
        currentRewrite = i;
        main();
      });
    });
  } else {
    metadata.appendChild(el("p", { textContent: `No rewrites in ${ctx.loc.filename}.` }));
  }
}
// Keyboard navigation:
//   ArrowUp / ArrowDown  - previous / next kernel while the list is collapsed, otherwise
//                          previous / next entry of the selected kernel
//   Enter                - collapse / expand the selected kernel
//   ArrowLeft / Right    - previous / next rewrite step of the current graph
// Every handled key re-renders via main().
document.addEventListener("keydown", (event) => {
  // Nothing to navigate until the first /kernels response has arrived.
  if (kernels == null) return;
  // keep an index inside [0, max]; also yields 0 when max is negative (empty list)
  const clamp = (value, max) => Math.max(0, Math.min(max, value));
  switch (event.key) {
    case "ArrowUp":
    case "ArrowDown": {
      event.preventDefault();
      const step = event.key === "ArrowUp" ? -1 : 1;
      currentRewrite = 0;
      if (!expandKernel) {
        currentUOp = 0;
        currentKernel = clamp(currentKernel + step, kernels.length - 1);
      } else {
        currentUOp = clamp(currentUOp + step, kernels[currentKernel].ctxs.length - 1);
      }
      main();
      break;
    }
    case "Enter":
      event.preventDefault();
      currentUOp = 0;
      currentRewrite = 0;
      expandKernel = !expandKernel;
      main();
      break;
    case "ArrowLeft":
    case "ArrowRight": {
      event.preventDefault();
      // `ret` stays empty until the first /graph response has been stored
      const graphs = ret[0]?.graphs;
      if (graphs == null) return;
      const step = event.key === "ArrowLeft" ? -1 : 1;
      currentRewrite = clamp(currentRewrite + step, graphs.length - 1);
      main();
      break;
    }
    default:
      break;
  }
});
// initial render
main();
</script>
</body>
</html>