diff --git a/viz/index.html b/viz/index.html
index dbe9e22f77..0829d406c3 100644
--- a/viz/index.html
+++ b/viz/index.html
@@ -135,25 +135,27 @@
totalUOps = rest.length-1;
totalRewrites = ctx.graphs.length-1;
// graph
- const g = new dagreD3.graphlib.Graph().setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
- const graph = ctx.graphs[currentRewrite];
- for ([k,u] of Object.entries(graph)) {
- g.setNode(k, {label: u[0], style: `fill: ${u[4]}; rx: 8; ry: 8;` });
- for (src of u[2]) {
- g.setEdge(src, k)
+ function renderGraph(graph) {
+ const g = new dagreD3.graphlib.Graph().setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
+ for ([k,u] of Object.entries(graph)) {
+ g.setNode(k, {label: u[0], style: `fill: ${u[4]}; rx: 8; ry: 8;` });
+ for (src of u[2]) {
+ g.setEdge(src, k)
+ }
}
+ const svg = d3.select("svg");
+ const inner = svg.select("g");
+ var zoom = d3.zoom()
+ .scaleExtent([0.05, 2])
+ .on("zoom", () => {
+ const transform = d3.event.transform;
+ inner.attr("transform", transform);
+ });
+ svg.call(zoom);
+ const render = new dagreD3.render();
+ render(inner, g);
}
- const svg = d3.select("svg");
- const inner = svg.select("g");
- var zoom = d3.zoom()
- .scaleExtent([0.05, 2])
- .on("zoom", () => {
- const transform = d3.event.transform;
- inner.attr("transform", transform);
- });
- svg.call(zoom);
- const render = new dagreD3.render();
- render(inner, g);
+ renderGraph(ctx.graphs[currentRewrite][0])
// metadata
const container = document.querySelector(".container.metadata");
container.innerHTML = "";
@@ -166,7 +168,7 @@
if (ctx.graphs.length > 1) {
const rewriteCounter = Object.assign(document.createElement("div"), { className: "rewrite-counter" });
container.appendChild(rewriteCounter)
- ctx.graphs.forEach((g, i) => {
+ ctx.graphs.forEach((_, i) => {
const rewriteDiv = Object.assign(document.createElement("div"), { textContent: i, className: "uop-el" });
if (i === currentRewrite) {
rewriteDiv.classList.add("active-uop-el");
diff --git a/viz/serve.py b/viz/serve.py
index bfcb5f6367..e1e91243a8 100755
--- a/viz/serve.py
+++ b/viz/serve.py
@@ -29,10 +29,10 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
@dataclass(frozen=True)
class UOpRet:
- loc: str # location that called graph_rewrite
- uops: List[UOp] # snapshot of the entire AST after each rewrite
- diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite
- extra: List[List[str]] # these become code blocks in the UI
+ loc: str
+ graphs: List[Tuple[UOp, UOp, UOp, UOp]] # snapshot of the entire AST after each rewrite
+ diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite
+ extra: List[List[str]] # these become code blocks in the UI
def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[bytes, UOp]) -> UOp:
if (found:=cache.get(base.key)): return found
@@ -45,6 +45,7 @@ def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[bytes, UOp]) -> UOp:
def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
uops: List[UOp] = [ctx.sink]
+ graphs: List[Tuple[UOp, UOp, UOp, UOp]] = [(ctx.sink, ctx.sink, ctx.sink, ctx.sink)]
diffs: List[Tuple[str, List[str]]] = []
extra: List[List[str]] = [[str(ctx.sink)]]
for (first, rewritten, pattern) in ctx.rewrites:
@@ -53,10 +54,11 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
# TODO: sometimes it hits a ctx and can't find any UOp to replace
#if new_sink is uops[-1]: continue
diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
- assert new_sink.op is UOps.SINK
+ assert new_sink.op is uops[-1].op
+ graphs.append((new_sink, uops[-1], rewritten, first))
uops.append(new_sink)
extra.append([str(new_sink)])
- return UOpRet(ctx.loc, uops, diffs, extra)
+ return UOpRet(ctx.loc, graphs, diffs, extra)
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
@@ -79,7 +81,8 @@ class Handler(BaseHTTPRequestHandler):
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
rest = [x.loc for x in contexts]
g = create_graph(contexts[int(self.path.split("/")[-1])])
- ret = json.dumps(({"loc": g.loc, "graphs": list(map(uop_to_json, g.uops)), "diffs": g.diffs, "extra": g.extra}, rest)).encode()
+ ret = json.dumps(({"loc": g.loc, "graphs": [[uop_to_json(x) for x in graph] for graph in g.graphs],
+ "diffs": g.diffs, "extra": g.extra}, rest)).encode()
else:
self.send_response(404)
ret = b""