diff --git a/tinygrad/viz/lib/worker.js b/tinygrad/viz/lib/worker.js index c4211776e6..ba3e7d1468 100644 --- a/tinygrad/viz/lib/worker.js +++ b/tinygrad/viz/lib/worker.js @@ -21,7 +21,7 @@ onmessage = (e) => { const g = new dagre.graphlib.Graph({ compound: true }); g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; }); if (additions.length !== 0) g.setNode("addition", {label: "", style: "fill: rgba(26, 27, 38, 0.5); stroke: none;", padding:0}); - for (const [k, [label, src, color]] of Object.entries(graph)) { + for (const [k, {label, src, color}] of Object.entries(graph)) { // adjust node dims by label size + add padding const [labelWidth, labelHeight] = getTextDims(label); g.setNode(k, {label, color, width:labelWidth+NODE_PADDING*2, height:labelHeight+NODE_PADDING*2, padding:NODE_PADDING}); diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 0632da7e62..2986eed2c8 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -49,10 +49,9 @@ class GraphRewriteDetails(TypedDict): changed_nodes: list[int]|None # the changed UOp id + all its parents ids upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat -def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]: +def uop_to_json(x:UOp) -> dict[int, dict]: assert isinstance(x, UOp) - # NOTE: this is [id, [label, src_ids, color]] - graph: dict[int, tuple[str, list[int], str]] = {} + graph: dict[int, dict] = {} excluded: set[UOp] = set() for u in (toposort:=x.toposort): # always exclude DEVICE/CONST/UNIQUE @@ -72,7 +71,7 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]: if x in excluded: if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}" else: label += f"\n{x.op.name}{idx} {x.arg}" - graph[id(u)] = (label, [id(x) for x in u.src if x not in excluded], uops_colors.get(u.op, "#ffffff")) + graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")} return graph def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]: