mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
move viz graph to lib/graph [pr] (#9196)
* move viz graph to lib/graph [pr] * add package * share with program
This commit is contained in:
2
setup.py
2
setup.py
@@ -25,7 +25,7 @@ setup(name='tinygrad',
|
||||
packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.runtime.autogen.am', 'tinygrad.codegen', 'tinygrad.nn',
|
||||
'tinygrad.renderer', 'tinygrad.engine', 'tinygrad.viz', 'tinygrad.runtime', 'tinygrad.runtime.support',
|
||||
'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape'],
|
||||
package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*']},
|
||||
package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*', 'lib/**/*']},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License"
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
|
||||
<script src="assets/d3js.org/d3.v5.min.js" charset="utf-8"></script>
|
||||
<script src="assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"></script>
|
||||
<script src="lib/graph.js"></script>
|
||||
<link rel="stylesheet" href="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
|
||||
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
|
||||
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
|
||||
@@ -233,53 +234,6 @@
|
||||
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
|
||||
}));
|
||||
|
||||
// **** D3
|
||||
function recenterRects(svg, zoom) {
|
||||
const svgBounds = svg.node().getBoundingClientRect();
|
||||
for (const rect of svg.node().querySelectorAll("rect")) {
|
||||
const rectBounds = rect.getBoundingClientRect();
|
||||
const outOfBounds = rectBounds.top < svgBounds.top || rectBounds.left < svgBounds.left ||
|
||||
rectBounds.bottom > svgBounds.bottom || rectBounds.right > svgBounds.right;
|
||||
// if there's at least one rect in view we don't do anything
|
||||
if (!outOfBounds) return;
|
||||
}
|
||||
svg.call(zoom.transform, d3.zoomIdentity)
|
||||
}
|
||||
function renderGraph(graph, additions) {
|
||||
const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
|
||||
for (const [k,u] of Object.entries(graph)) {
|
||||
let node = {label: u[0], labelType: "text", style: `fill: ${u[2]};`};
|
||||
// for PROGRAM UOp we render the node with a code block
|
||||
if (u[0].includes("PROGRAM")) {
|
||||
const [name, ...rest] = u[0].split("\n");
|
||||
const label = Object.assign(document.createElement("div"));
|
||||
label.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" }))
|
||||
label.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true));
|
||||
node = {label, labelType: "html", style: `fill: ${u[2]}`};
|
||||
}
|
||||
g.setNode(k, node);
|
||||
for (const src of u[1]) {
|
||||
g.setEdge(src, k, {curve: d3.curveBasis})
|
||||
}
|
||||
if (additions.includes(parseInt(k))) {
|
||||
g.setParent(k, "addition");
|
||||
}
|
||||
}
|
||||
const svg = d3.select("#graph-svg");
|
||||
const inner = svg.select("g");
|
||||
var zoom = d3.zoom()
|
||||
.scaleExtent([0.05, 2])
|
||||
.on("zoom", () => {
|
||||
const transform = d3.event.transform;
|
||||
inner.attr("transform", transform);
|
||||
});
|
||||
recenterRects(svg, zoom);
|
||||
svg.call(zoom);
|
||||
const render = new dagreD3.render();
|
||||
render(inner, g);
|
||||
}
|
||||
|
||||
// **** extra helpers
|
||||
const toPath = ([fp, lineno]) => `${fp.replaceAll("\\", "/").split("/").pop()}:${lineno}`;
|
||||
const vsCodeOpener = (parts) => Object.assign(document.createElement("a"), { textContent: parts[parts.length-1]+"\n\n",
|
||||
|
||||
48
tinygrad/viz/lib/graph.js
Normal file
48
tinygrad/viz/lib/graph.js
Normal file
@@ -0,0 +1,48 @@
|
||||
function recenterRects(svg, zoom) {
|
||||
const svgBounds = svg.node().getBoundingClientRect();
|
||||
for (const rect of svg.node().querySelectorAll("rect")) {
|
||||
const rectBounds = rect.getBoundingClientRect();
|
||||
const outOfBounds = rectBounds.top < svgBounds.top || rectBounds.left < svgBounds.left ||
|
||||
rectBounds.bottom > svgBounds.bottom || rectBounds.right > svgBounds.right;
|
||||
// if there's at least one rect in view we don't do anything
|
||||
if (!outOfBounds) return;
|
||||
}
|
||||
svg.call(zoom.transform, d3.zoomIdentity)
|
||||
}
|
||||
|
||||
window.renderGraph = function(graph, additions) {
|
||||
// ** initialize the graph
|
||||
const g = new dagreD3.dagre.graphlib.Graph({ compound: true });
|
||||
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
|
||||
for (const [k,v] of Object.entries(graph)) {
|
||||
const [label, src, color] = v;
|
||||
const node = {label, labelType:"text", style:`fill: ${color};`};
|
||||
// for PROGRAM UOp we render the node with a code block
|
||||
if (label.includes("PROGRAM")) {
|
||||
const [name, ...rest] = label.split("\n");
|
||||
const labelEl = Object.assign(document.createElement("div"));
|
||||
labelEl.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" }));
|
||||
labelEl.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true));
|
||||
node.label = labelEl;
|
||||
node.labelType = "html";
|
||||
}
|
||||
g.setNode(k, node);
|
||||
for (const s of src) g.setEdge(s, k, {curve: d3.curveBasis});
|
||||
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
|
||||
}
|
||||
|
||||
// ** select svg render
|
||||
const svg = d3.select("#graph-svg");
|
||||
const inner = svg.select("g");
|
||||
const zoom = d3.zoom().scaleExtent([0.05, 2]).on("zoom", () => {
|
||||
const transform = d3.event.transform;
|
||||
inner.attr("transform", transform);
|
||||
});
|
||||
recenterRects(svg, zoom);
|
||||
svg.call(zoom);
|
||||
|
||||
// ** calculate layout + render
|
||||
const render = new dagreD3.render();
|
||||
render(inner, g);
|
||||
}
|
||||
@@ -124,7 +124,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
|
||||
elif (url:=urlparse(self.path)).path == "/profiler":
|
||||
with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
|
||||
elif self.path.startswith("/assets/") and '/..' not in self.path:
|
||||
elif (self.path.startswith("/assets/") or self.path.startswith("/lib/")) and '/..' not in self.path:
|
||||
try:
|
||||
with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
|
||||
if url.path.endswith(".js"): content_type = "application/javascript"
|
||||
|
||||
Reference in New Issue
Block a user