Files
tinygrad/viz/index.html
2024-09-23 16:44:19 +08:00

347 lines
10 KiB
HTML

<!DOCTYPE html>
<html>
<head>
<title>tinygrad viz</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" href="favicon.svg" type="image/svg+xml">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Noto+Sans+HK:wght@100..900&display=swap" rel="stylesheet">
<script src="https://d3js.org/d3.v5.min.js" charset="utf-8"></script>
<script src="https://dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"></script>
<style>
* {
box-sizing: border-box;
margin-block-start: initial;
margin-block-end: initial;
}
html, body {
background-color: #191919;
color: #EEEEEE;
margin: 0;
padding: 0;
width: 100%;
height: 100%;
font-family: "Noto Sans", sans-serif;
font-optical-sizing: auto;
font-weight: 400;
font-style: normal;
font-variation-settings: "wdth" 100;
font-size: 14px;
overflow: hidden;
}
a {
color: #0090FF;
}
ul {
padding: 0;
color: #7B7B7B;
cursor: pointer;
}
ul.active {
color: #EEEEEE;
}
svg {
width: 100%;
height: 100%;
}
svg * {
cursor: default;
user-select: none;
}
.node rect {
stroke: #313131;
stroke-width: 1.5px;
}
.label text {
color: black;
font-weight: 350;
}
.edgePath path {
stroke: #606060;
stroke-width: 1.5px;
}
.graph {
grid-column: span 10;
position: relative;
}
.main-container {
display: grid;
grid-template-columns: repeat(14, 1fr);
gap: 8px;
padding: 12px;
width: 100%;
height: 100%;
background-color: #191919;
}
.container {
height: 100%;
background-color: #111111;
border-radius: 8px;
padding: 8px;
}
.container > * + * {
margin-top: 12px;
}
.kernel-list {
grid-column: span 1;
display: flex;
flex-direction: column;
overflow-y: auto;
}
.kernel-list > ul > ul {
padding-left: 4px;
}
.kernel-list > ul > * + * {
margin-top: 4px;
}
.metadata {
grid-column: span 3;
overflow-y: auto;
}
.rewrite-list {
display: flex;
flex-wrap: wrap;
}
.rewrite-list > * + * {
margin-left: 4px;
}
.wrap {
word-wrap: break-word;
white-space: pre-wrap;
}
.code-block {
max-height: 30%;
overflow-y: auto;
background-color: #191919;
border-radius: 8px;
padding: 8px;
}
.spinner {
position: absolute;
left: 50%;
top: 50%;
width: 32px;
height: 32px;
border: 4px solid #B4B4B4;
border-bottom-color: transparent;
border-radius: 50%;
display: inline-block;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
.status {
color: #EC5D5E;
font-weight: bold;
}
</style>
</head>
<body>
<div class="main-container">
<div class="container kernel-list"></div>
<div class="container graph">
<div class="status"></div>
<svg>
<g id="render"></g>
</svg>
</div>
<div class="container metadata"></div>
</div>
<script>
/**
 * Builds a dagre-d3 graph from one UOp graph and draws it into the page's <svg>.
 *
 * @param {Object} graph - Maps node id -> array describing the node. Only three
 *   slots are read here: index 0 is the node label, index 2 is an iterable of
 *   source node ids (one edge is added from each source to this node), and
 *   index 4 is the fill color.
 * @param {Array<number>} additions - Numeric node ids that are grouped under the
 *   "addition" cluster. When empty, the cluster is hidden.
 */
function renderGraph(graph, additions) {
  const g = new dagreD3.graphlib.Graph({ compound: true })
    .setGraph({ rankdir: "LR" })
    .setDefaultEdgeLabel(() => ({}));
  // The cluster node is always created; it is only made visible when there is
  // at least one added node to put inside it.
  g.setNode("addition", {
    label: "",
    clusterLabelPos: "top",
    style: additions.length !== 0 ? "fill: #2ea04326" : "display: none;",
  });
  // Loop variables are declared locally so they no longer leak as globals.
  for (const [k, u] of Object.entries(graph)) {
    g.setNode(k, { label: u[0], style: `fill: ${u[4]}; rx: 8; ry: 8;` });
    for (const src of u[2]) {
      g.setEdge(src, k);
    }
    // Object keys are strings while `additions` holds numbers.
    if (additions.includes(Number.parseInt(k, 10))) {
      g.setParent(k, "addition");
    }
  }
  const svg = d3.select("svg");
  const inner = svg.select("g");
  const zoom = d3.zoom()
    .scaleExtent([0.05, 2])
    .on("zoom", () => {
      // d3 v5 exposes the active event on the d3 namespace.
      inner.attr("transform", d3.event.transform);
    });
  svg.call(zoom);
  const render = new dagreD3.render();
  render(inner, g);
}
/**
 * Pings the server root and updates the `.status` banner accordingly.
 *
 * The banner is cleared when `fetch("/")` resolves with an `ok` response and
 * shows a connection error message otherwise (non-ok response or a rejected
 * fetch).
 *
 * @returns {Promise<void>}
 */
async function checkStatus() {
  // Kept local: nothing outside this function reads the flag.
  let active;
  try {
    active = (await fetch("/")).ok;
  } catch {
    // A rejected fetch is an expected outcome here; it is reported through
    // the banner below rather than rethrown.
    active = false;
  }
  document.querySelector('.status').textContent = active ? `` : `Connection to localhost:8000 failed, is VIZ=1 running?`;
}
// ** page state shared by main() and the keydown handler
// Most recent /graph response (fetched or taken from `cache`); main() reads ret[0].
var ret = {};
// /graph responses keyed by `${currentKernel}-${currentUOp}` so each pair is fetched once.
var cache = {};
// Kernel list fetched from /kernels; null until main() has loaded it.
var kernels = null;
// Index into kernels[currentKernel].ctxs of the selected entry.
var currentUOp = 0;
// Index into `kernels` of the selected kernel.
var currentKernel = 0;
// Index into ret[0].graphs of the rewrite step being shown.
var currentRewrite = 0;
// Whether the selected kernel's ctxs entries are displayed in the kernel list.
var expandKernel = false;
/**
 * Re-renders the whole page from the current selection state.
 *
 * Steps, in order:
 *  1. Kick off a status check (not awaited).
 *  2. Load the kernel list from /kernels on first use and rebuild the
 *     kernel-list panel, wiring click handlers that update the selection and
 *     call main() again.
 *  3. Load (or reuse from `cache`) the /graph response for the selected
 *     kernel/uop pair and draw the selected rewrite step with renderGraph().
 *  4. Rebuild the metadata panel: location, extra text blocks, kernel code,
 *     and the rewrite list with the matched pattern and diff of the selected
 *     rewrite.
 *
 * All server-provided text is inserted with textContent/innerText so that
 * characters such as `<` and `&` in code or diffs are displayed literally
 * instead of being parsed as markup.
 *
 * @returns {Promise<void>}
 */
async function main() {
  // Not awaited on purpose: the status banner updates independently.
  checkStatus();

  // Wraps plain text in <pre class="code-block"><code>…</code></pre>.
  const makeCodeBlock = (text) => {
    const pre = Object.assign(document.createElement("pre"), { className: "code-block" });
    pre.appendChild(Object.assign(document.createElement("code"), { textContent: text }));
    return pre;
  };

  // ** kernel list
  if (kernels == null) {
    kernels = await (await fetch("/kernels")).json();
    currentKernel = 0;
  }
  const kernelList = document.querySelector(".container.kernel-list");
  kernelList.innerHTML = "";
  kernels.forEach((k, i) => {
    const kernelUl = Object.assign(document.createElement("ul"), { key: `kernel-${i}`, className: i === currentKernel ? "active" : "" });
    const p = Object.assign(document.createElement("p"), { id: `kernel-${k.name}`, innerText: k.name });
    kernelUl.appendChild(p);
    k.ctxs.forEach((u, j) => {
      const rwUl = Object.assign(document.createElement("ul"), { innerText: u, key: `uop-rewrite-${j}`, className: (j === currentUOp && i === currentKernel) ? "active" : "" });
      // Entries are only visible for the selected kernel while it is expanded.
      rwUl.style.display = i === currentKernel && expandKernel ? "block" : "none";
      rwUl.onclick = (e) => {
        e.stopPropagation();
        currentUOp = j;
        currentKernel = i;
        currentRewrite = 0;
        main();
      };
      kernelUl.appendChild(rwUl);
    });
    p.onclick = () => {
      if (i === currentKernel) {
        // Clicking the already-selected kernel only toggles its expansion.
        expandKernel = !expandKernel;
      } else {
        currentKernel = i;
        currentUOp = 0;
        currentRewrite = 0;
        expandKernel = true;
      }
      main();
    };
    kernelList.appendChild(kernelUl);
  });

  // ** uop graph
  const cacheKey = `${currentKernel}-${currentUOp}`;
  if (cacheKey in cache) {
    ret = cache[cacheKey];
  } else {
    ret = await (await fetch(`/graph?kernel_idx=${currentKernel}&uop_idx=${currentUOp}`)).json();
    cache[cacheKey] = ret;
  }
  // No awaits below this point, so `trace` stays consistent with `ret`.
  const trace = ret[0];
  renderGraph(trace.graphs[currentRewrite], trace.additions[currentRewrite]);

  // ** metadata
  const metadata = document.querySelector(".container.metadata");
  metadata.innerHTML = "";
  metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: trace.loc }));
  trace.extra[currentRewrite].forEach((e) => {
    if (e.length === 0) return;
    metadata.appendChild(makeCodeBlock(e));
  });
  if (kernels[currentKernel].code !== "") {
    metadata.appendChild(makeCodeBlock(kernels[currentKernel].code));
  }

  // ** rewrite list
  if (trace.graphs.length > 1) {
    const rewriteList = Object.assign(document.createElement("div"), { className: "rewrite-list" });
    metadata.appendChild(rewriteList);
    trace.graphs.forEach((_, i) => {
      const gUl = Object.assign(document.createElement("ul"), { innerText: i, key: `rewrite-${i}` });
      rewriteList.appendChild(gUl);
      if (i === currentRewrite) {
        gUl.classList.add("active");
        // Step 0 has no diff entry; diffs[i-1] belongs to step i.
        if (i !== 0) {
          const [pattern, loc, diff] = trace.diffs[i-1];
          const pre = Object.assign(document.createElement("pre"), { className: "code-block wrap" });
          const parts = loc.join(":").split("/");
          // Link text is the last path segment; the href keeps the full path.
          pre.appendChild(Object.assign(document.createElement("a"), { innerText: parts[parts.length-1]+"\n\n", href: "vscode://file"+parts.join("/") }));
          pre.appendChild(Object.assign(document.createElement("code"), { innerText: pattern }));
          metadata.appendChild(pre);

          // One colored <span> per diff line, separated by <br>.
          const diffCode = document.createElement("code");
          diff.forEach((line, n) => {
            if (n > 0) diffCode.appendChild(document.createElement("br"));
            const span = Object.assign(document.createElement("span"), { textContent: line });
            span.style.color = line.startsWith("+") ? "#30A46C" : line.startsWith("-") ? "#E5484D" : "#EEEEEE";
            diffCode.appendChild(span);
          });
          const diffPre = Object.assign(document.createElement("pre"), { className: "wrap" });
          diffPre.appendChild(diffCode);
          metadata.appendChild(diffPre);
        }
      }
      gUl.addEventListener("click", () => {
        currentRewrite = i;
        main();
      });
    });
  } else {
    metadata.appendChild(Object.assign(document.createElement("p"), { textContent: "No rewrites" }));
  }
}
// ** keyboard navigation
// Enter       toggles expansion of the selected kernel and resets the uop/rewrite selection.
// Up / Down   move between kernels while collapsed, or between the selected
//             kernel's ctxs entries while expanded; both reset the rewrite index.
// Left/Right  step through the rewrites of the currently loaded graph.
// Every handled key re-renders the page through main().
document.addEventListener("keydown", function(event) {
  // Nothing to navigate until main() has loaded the kernel list; without this
  // guard an early key press would dereference `kernels` while it is null.
  if (kernels == null) return;

  switch (event.key) {
    case "Enter":
      event.preventDefault();
      currentUOp = 0;
      currentRewrite = 0;
      expandKernel = !expandKernel;
      main();
      break;

    case "ArrowUp":
      event.preventDefault();
      currentRewrite = 0;
      if (!expandKernel) {
        currentUOp = 0;
        currentKernel = Math.max(0, currentKernel-1);
      } else {
        currentUOp = Math.max(0, currentUOp-1);
      }
      main();
      break;

    case "ArrowDown":
      event.preventDefault();
      currentRewrite = 0;
      if (!expandKernel) {
        currentUOp = 0;
        currentKernel = Math.min(kernels.length-1, currentKernel+1);
      } else {
        const totalUOps = kernels[currentKernel].ctxs.length-1;
        currentUOp = Math.min(totalUOps, currentUOp+1);
      }
      main();
      break;

    // left and right go through rewrites in a single UOp
    case "ArrowLeft":
      event.preventDefault();
      currentRewrite = Math.max(0, currentRewrite-1);
      main();
      break;

    case "ArrowRight": {
      event.preventDefault();
      // `ret` starts out as {} and only holds a graph response once main()
      // has fetched one.
      const trace = ret[0];
      if (!trace) break;
      const totalRewrites = trace.graphs.length-1;
      currentRewrite = Math.min(totalRewrites, currentRewrite+1);
      main();
      break;
    }

    default:
      break;
  }
});
// Refresh the connection banner every 5 seconds, then do the initial render.
setInterval(checkStatus, 5000);
main();
</script>
</body>
</html>