mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
viz UOp layout cleanup (#13787)
* use the same names in server and client * first layout args, then renderer args
This commit is contained in:
@@ -109,12 +109,12 @@ async function initWorker() {
|
||||
workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" }));
|
||||
}
|
||||
|
||||
function renderDag(graph, additions, recenter, layoutOpts) {
|
||||
function renderDag(layoutSpec, { recenter }) {
|
||||
// start calculating the new layout (non-blocking)
|
||||
updateProgress(Status.STARTED, "Rendering new graph...");
|
||||
if (worker != null) worker.terminate();
|
||||
worker = new Worker(workerUrl);
|
||||
worker.postMessage({graph, additions, opts:layoutOpts });
|
||||
worker.postMessage(layoutSpec);
|
||||
worker.onmessage = (e) => {
|
||||
displaySelection("#graph");
|
||||
updateProgress(Status.COMPLETE);
|
||||
@@ -837,7 +837,7 @@ async function main() {
|
||||
if (ret.length === 0) return;
|
||||
// ** center graph
|
||||
const data = ret[currentRewrite];
|
||||
const render = (opts) => renderDag(data.graph, data.changed_nodes ?? [], currentRewrite === 0, opts);
|
||||
const render = (opts) => renderDag({ graph:data.graph, change:data.change, opts }, { recenter:currentRewrite === 0 });
|
||||
render({ showIndexing:toggle.checked });
|
||||
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
|
||||
// ** right sidebar metadata
|
||||
|
||||
@@ -5,11 +5,11 @@ const ctx = canvas.getContext("2d");
|
||||
ctx.font = `350 ${LINE_HEIGHT}px sans-serif`;
|
||||
|
||||
onmessage = (e) => {
|
||||
const { graph, additions, opts } = e.data;
|
||||
const { graph, change, opts } = e.data;
|
||||
const g = new dagre.graphlib.Graph({ compound: true });
|
||||
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
if (additions.length !== 0) g.setNode("addition", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
|
||||
for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
|
||||
if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
|
||||
for (const [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
|
||||
// adjust node dims by label size (excluding escape codes) + add padding
|
||||
let [width, height] = [0, 0];
|
||||
for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
|
||||
@@ -18,10 +18,10 @@ onmessage = (e) => {
|
||||
}
|
||||
g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, label, labelHeight:height, labelWidth:width, ref, id:k, ...rest});
|
||||
// add edges
|
||||
const edgeCounts = {}
|
||||
const edgeCounts = {};
|
||||
for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;
|
||||
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
|
||||
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
|
||||
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
|
||||
}
|
||||
// optionally hide nodes from the layuot
|
||||
if (!opts.showIndexing) {
|
||||
@@ -31,8 +31,8 @@ onmessage = (e) => {
|
||||
}
|
||||
}
|
||||
dagre.layout(g);
|
||||
// remove additions overlay if it's empty
|
||||
if (!g.node("addition")?.width) g.removeNode("addition");
|
||||
// remove overlay node if it's empty
|
||||
if (!g.node("overlay")?.width) g.removeNode("overlay");
|
||||
postMessage(dagre.graphlib.json.write(g));
|
||||
self.close();
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ class GraphRewriteDetails(TypedDict):
|
||||
graph: dict # JSON serialized UOp for this rewrite step
|
||||
uop: str # strigified UOp for this rewrite step
|
||||
diff: list[str]|None # diff of the single UOp that changed
|
||||
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
|
||||
change: list[int]|None # the new UOp id + all its parents ids
|
||||
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
|
||||
|
||||
def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
|
||||
@@ -115,14 +115,14 @@ def _reconstruct(a:int):
|
||||
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
next_sink = _reconstruct(ctx.sink)
|
||||
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
|
||||
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "changed_nodes":None, "diff":None, "upat":None}
|
||||
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
||||
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
|
||||
try: new_sink = next_sink.substitute(replaces)
|
||||
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
|
||||
if not ctx.bottom_up: next_sink = new_sink
|
||||
|
||||
|
||||
Reference in New Issue
Block a user