viz UOp layout cleanup (#13787)

* use the same names in server and client

* first layout args, then renderer args
This commit is contained in:
qazal
2025-12-21 23:11:40 +09:00
committed by GitHub
parent e523971028
commit 9839838fdd
3 changed files with 13 additions and 13 deletions

View File

@@ -109,12 +109,12 @@ async function initWorker() {
workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" }));
}
function renderDag(graph, additions, recenter, layoutOpts) {
function renderDag(layoutSpec, { recenter }) {
// start calculating the new layout (non-blocking)
updateProgress(Status.STARTED, "Rendering new graph...");
if (worker != null) worker.terminate();
worker = new Worker(workerUrl);
worker.postMessage({graph, additions, opts:layoutOpts });
worker.postMessage(layoutSpec);
worker.onmessage = (e) => {
displaySelection("#graph");
updateProgress(Status.COMPLETE);
@@ -837,7 +837,7 @@ async function main() {
if (ret.length === 0) return;
// ** center graph
const data = ret[currentRewrite];
const render = (opts) => renderDag(data.graph, data.changed_nodes ?? [], currentRewrite === 0, opts);
const render = (opts) => renderDag({ graph:data.graph, change:data.change, opts }, { recenter:currentRewrite === 0 });
render({ showIndexing:toggle.checked });
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
// ** right sidebar metadata

View File

@@ -5,11 +5,11 @@ const ctx = canvas.getContext("2d");
ctx.font = `350 ${LINE_HEIGHT}px sans-serif`;
onmessage = (e) => {
const { graph, additions, opts } = e.data;
const { graph, change, opts } = e.data;
const g = new dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
if (additions.length !== 0) g.setNode("addition", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
for (const [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
// adjust node dims by label size (excluding escape codes) + add padding
let [width, height] = [0, 0];
for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
@@ -18,10 +18,10 @@ onmessage = (e) => {
}
g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, label, labelHeight:height, labelWidth:width, ref, id:k, ...rest});
// add edges
const edgeCounts = {}
const edgeCounts = {};
for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
}
// optionally hide nodes from the layuot
if (!opts.showIndexing) {
@@ -31,8 +31,8 @@ onmessage = (e) => {
}
}
dagre.layout(g);
// remove additions overlay if it's empty
if (!g.node("addition")?.width) g.removeNode("addition");
// remove overlay node if it's empty
if (!g.node("overlay")?.width) g.removeNode("overlay");
postMessage(dagre.graphlib.json.write(g));
self.close();
}

View File

@@ -53,7 +53,7 @@ class GraphRewriteDetails(TypedDict):
graph: dict # JSON serialized UOp for this rewrite step
uop: str # strigified UOp for this rewrite step
diff: list[str]|None # diff of the single UOp that changed
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
change: list[int]|None # the new UOp id + all its parents ids
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
@@ -115,14 +115,14 @@ def _reconstruct(a:int):
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
next_sink = _reconstruct(ctx.sink)
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "changed_nodes":None, "diff":None, "upat":None}
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
if not ctx.bottom_up: next_sink = new_sink