mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
refactor uop_to_json to return a dict [pr] (#9560)
This commit is contained in:
@@ -21,7 +21,7 @@ onmessage = (e) => {
|
||||
const g = new dagre.graphlib.Graph({ compound: true });
|
||||
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
if (additions.length !== 0) g.setNode("addition", {label: "", style: "fill: rgba(26, 27, 38, 0.5); stroke: none;", padding:0});
|
||||
for (const [k, [label, src, color]] of Object.entries(graph)) {
|
||||
for (const [k, {label, src, color}] of Object.entries(graph)) {
|
||||
// adjust node dims by label size + add padding
|
||||
const [labelWidth, labelHeight] = getTextDims(label);
|
||||
g.setNode(k, {label, color, width:labelWidth+NODE_PADDING*2, height:labelHeight+NODE_PADDING*2, padding:NODE_PADDING});
|
||||
|
||||
@@ -49,10 +49,9 @@ class GraphRewriteDetails(TypedDict):
|
||||
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
|
||||
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
|
||||
|
||||
def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
|
||||
def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
assert isinstance(x, UOp)
|
||||
# NOTE: this is [id, [label, src_ids, color]]
|
||||
graph: dict[int, tuple[str, list[int], str]] = {}
|
||||
graph: dict[int, dict] = {}
|
||||
excluded: set[UOp] = set()
|
||||
for u in (toposort:=x.toposort):
|
||||
# always exclude DEVICE/CONST/UNIQUE
|
||||
@@ -72,7 +71,7 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
|
||||
if x in excluded:
|
||||
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
|
||||
else: label += f"\n{x.op.name}{idx} {x.arg}"
|
||||
graph[id(u)] = (label, [id(x) for x in u.src if x not in excluded], uops_colors.get(u.op, "#ffffff"))
|
||||
graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")}
|
||||
return graph
|
||||
|
||||
def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
|
||||
Reference in New Issue
Block a user