refactor uop_to_json to return a dict [pr] (#9560)

This commit is contained in:
qazal
2025-03-24 16:38:17 +08:00
committed by GitHub
parent edf9e1bf8d
commit 1cfe6d02fe
2 changed files with 4 additions and 5 deletions

View File

@@ -21,7 +21,7 @@ onmessage = (e) => {
const g = new dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
if (additions.length !== 0) g.setNode("addition", {label: "", style: "fill: rgba(26, 27, 38, 0.5); stroke: none;", padding:0});
for (const [k, [label, src, color]] of Object.entries(graph)) {
for (const [k, {label, src, color}] of Object.entries(graph)) {
// adjust node dims by label size + add padding
const [labelWidth, labelHeight] = getTextDims(label);
g.setNode(k, {label, color, width:labelWidth+NODE_PADDING*2, height:labelHeight+NODE_PADDING*2, padding:NODE_PADDING});

View File

@@ -49,10 +49,9 @@ class GraphRewriteDetails(TypedDict):
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
def uop_to_json(x:UOp) -> dict[int, dict]:
assert isinstance(x, UOp)
# NOTE: this is [id, [label, src_ids, color]]
graph: dict[int, tuple[str, list[int], str]] = {}
graph: dict[int, dict] = {}
excluded: set[UOp] = set()
for u in (toposort:=x.toposort):
# always exclude DEVICE/CONST/UNIQUE
@@ -72,7 +71,7 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
if x in excluded:
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
else: label += f"\n{x.op.name}{idx} {x.arg}"
graph[id(u)] = (label, [id(x) for x in u.src if x not in excluded], uops_colors.get(u.op, "#ffffff"))
graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")}
return graph
def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]: