viz server cleanups (#13668)

* viz server cleanups

* comment
This commit is contained in:
qazal
2025-12-13 04:27:53 -05:00
committed by GitHub
parent f6cc3b13b9
commit a6dfd8a672

View File

@@ -58,7 +58,8 @@ class GraphRewriteDetails(TypedDict):
def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")"
def pystr(u:UOp, i:int) -> str:
def pystr(u:UOp) -> str:
# pyrender may check for shape mismatch
try: return pyrender(u)
except Exception: return str(u)
@@ -111,19 +112,18 @@ def _reconstruct(a:int):
arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
next_sink = _reconstruct(ctx.sink)
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "changed_nodes":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink,i),
"changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(pystr(u0,i).splitlines(),pystr(u1,i).splitlines())), "upat":(upat_loc, match_repr)}
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
if not ctx.bottom_up: next_sink = new_sink
# encoder helpers
@@ -445,7 +445,7 @@ class Handler(HTTPRequestHandler):
elif (query:=parse_qs(url.query)):
i, j = get_int(query, "ctx"), get_int(query, "step")
if (fmt:=url.path.lstrip("/")) == "rewrites":
try: return self.stream_json(get_full_rewrite(trace.rewrites[i][j], i))
try: return self.stream_json(get_full_rewrite(trace.rewrites[i][j]))
except (KeyError, IndexError): status_code = 404
else:
render_src = get_render(i, j, fmt)