diff --git a/viz/index.html b/viz/index.html index dbe9e22f77..0829d406c3 100644 --- a/viz/index.html +++ b/viz/index.html @@ -135,25 +135,27 @@ totalUOps = rest.length-1; totalRewrites = ctx.graphs.length-1; // graph - const g = new dagreD3.graphlib.Graph().setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; }); - const graph = ctx.graphs[currentRewrite]; - for ([k,u] of Object.entries(graph)) { - g.setNode(k, {label: u[0], style: `fill: ${u[4]}; rx: 8; ry: 8;` }); - for (src of u[2]) { - g.setEdge(src, k) + function renderGraph(graph) { + const g = new dagreD3.graphlib.Graph().setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; }); + for ([k,u] of Object.entries(graph)) { + g.setNode(k, {label: u[0], style: `fill: ${u[4]}; rx: 8; ry: 8;` }); + for (src of u[2]) { + g.setEdge(src, k) + } } + const svg = d3.select("svg"); + const inner = svg.select("g"); + var zoom = d3.zoom() + .scaleExtent([0.05, 2]) + .on("zoom", () => { + const transform = d3.event.transform; + inner.attr("transform", transform); + }); + svg.call(zoom); + const render = new dagreD3.render(); + render(inner, g); } - const svg = d3.select("svg"); - const inner = svg.select("g"); - var zoom = d3.zoom() - .scaleExtent([0.05, 2]) - .on("zoom", () => { - const transform = d3.event.transform; - inner.attr("transform", transform); - }); - svg.call(zoom); - const render = new dagreD3.render(); - render(inner, g); + renderGraph(ctx.graphs[currentRewrite][0]) // metadata const container = document.querySelector(".container.metadata"); container.innerHTML = ""; @@ -166,7 +168,7 @@ if (ctx.graphs.length > 1) { const rewriteCounter = Object.assign(document.createElement("div"), { className: "rewrite-counter" }); container.appendChild(rewriteCounter) - ctx.graphs.forEach((g, i) => { + ctx.graphs.forEach((_, i) => { const rewriteDiv = Object.assign(document.createElement("div"), { textContent: i, className: "uop-el" }); if (i === currentRewrite) { rewriteDiv.classList.add("active-uop-el"); diff --git a/viz/serve.py b/viz/serve.py index bfcb5f6367..e1e91243a8 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -29,10 +29,10 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: @dataclass(frozen=True) class UOpRet: - loc: str # location that called graph_rewrite - uops: List[UOp] # snapshot of the entire AST after each rewrite - diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite - extra: List[List[str]] # these become code blocks in the UI + loc: str + graphs: List[Tuple[UOp, UOp, UOp, UOp]] # snapshot of the entire AST after each rewrite + diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite + extra: List[List[str]] # these become code blocks in the UI def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[bytes, UOp]) -> UOp: if (found:=cache.get(base.key)): return found @@ -45,6 +45,7 @@ def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[bytes, UOp]) -> UOp: def create_graph(ctx:TrackedRewriteContext) -> UOpRet: uops: List[UOp] = [ctx.sink] + graphs: List[Tuple[UOp, UOp, UOp, UOp]] = [(ctx.sink, ctx.sink, ctx.sink, ctx.sink)] diffs: List[Tuple[str, List[str]]] = [] extra: List[List[str]] = [[str(ctx.sink)]] for (first, rewritten, pattern) in ctx.rewrites: @@ -53,10 +54,11 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet: # TODO: sometimes it hits a ctx and can't find any UOp to replace #if new_sink is uops[-1]: continue diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))) - assert new_sink.op is UOps.SINK + assert new_sink.op is uops[-1].op + graphs.append((new_sink, uops[-1], rewritten, first)) uops.append(new_sink) extra.append([str(new_sink)]) - return UOpRet(ctx.loc, uops, diffs, extra) + return UOpRet(ctx.loc, graphs, diffs, extra) class Handler(BaseHTTPRequestHandler): def do_GET(self): @@ -79,7 +81,8 @@ class Handler(BaseHTTPRequestHandler): with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f) rest = [x.loc for x in contexts] g = create_graph(contexts[int(self.path.split("/")[-1])]) - ret = json.dumps(({"loc": g.loc, "graphs": list(map(uop_to_json, g.uops)), "diffs": g.diffs, "extra": g.extra}, rest)).encode() + ret = json.dumps(({"loc": g.loc, "graphs": [[uop_to_json(x) for x in graph] for graph in g.graphs], + "diffs": g.diffs, "extra": g.extra}, rest)).encode() else: self.send_response(404) ret = b""