mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
@@ -58,7 +58,8 @@ class GraphRewriteDetails(TypedDict):
|
||||
|
||||
def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
|
||||
def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")"
|
||||
def pystr(u:UOp, i:int) -> str:
|
||||
def pystr(u:UOp) -> str:
|
||||
# pyrender may check for shape mismatch
|
||||
try: return pyrender(u)
|
||||
except Exception: return str(u)
|
||||
|
||||
@@ -111,19 +112,18 @@ def _reconstruct(a:int):
|
||||
arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
|
||||
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
|
||||
|
||||
def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
|
||||
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
next_sink = _reconstruct(ctx.sink)
|
||||
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
|
||||
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
|
||||
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "changed_nodes":None, "diff":None, "upat":None}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
||||
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
|
||||
try: new_sink = next_sink.substitute(replaces)
|
||||
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink,i),
|
||||
"changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
"diff":list(difflib.unified_diff(pystr(u0,i).splitlines(),pystr(u1,i).splitlines())), "upat":(upat_loc, match_repr)}
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
|
||||
if not ctx.bottom_up: next_sink = new_sink
|
||||
|
||||
# encoder helpers
|
||||
@@ -445,7 +445,7 @@ class Handler(HTTPRequestHandler):
|
||||
elif (query:=parse_qs(url.query)):
|
||||
i, j = get_int(query, "ctx"), get_int(query, "step")
|
||||
if (fmt:=url.path.lstrip("/")) == "rewrites":
|
||||
try: return self.stream_json(get_full_rewrite(trace.rewrites[i][j], i))
|
||||
try: return self.stream_json(get_full_rewrite(trace.rewrites[i][j]))
|
||||
except (KeyError, IndexError): status_code = 404
|
||||
else:
|
||||
render_src = get_render(i, j, fmt)
|
||||
|
||||
Reference in New Issue
Block a user