add support for named rewrites [pr] (#9152)

This commit is contained in:
George Hotz
2025-02-18 16:07:04 +08:00
committed by GitHub
parent caee42e8a6
commit 6d62966bf7
4 changed files with 14 additions and 7 deletions

View File

@@ -407,6 +407,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# remove_movement_ops + sym
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
# do_realize + group_realizes
buffer_map: dict[UOp, UOp] = {}
sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
@@ -452,8 +455,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
raise RuntimeError(f"cycle detected in graph, kernel must either depend on ASSIGN or BUFFER for {k}")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep: sched_sink = sched_sink.substitute(assign_rep)
# display the final graph
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
# final toposort (bfs)
children: dict[UOp, list[UOp]] = {}

View File

@@ -853,6 +853,7 @@ class TrackedGraphRewrite:
sink: UOp # the sink input to graph_rewrite
bottom_up: bool
matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
name: Optional[str] = None
tracked_keys:list[Any] = []
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
_name_cnt:dict[str, int] = {}
@@ -935,14 +936,14 @@ class RewriteContext:
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
return ret
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
rewrite_ctx = RewriteContext(pm, ctx)
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}

View File

@@ -327,7 +327,8 @@
const p = Object.assign(document.createElement("p"), { id: `kernel-${key}`, innerHTML: coloredToHTML(key), style: "cursor: pointer;"});
kernelUl.appendChild(p)
items.forEach((u, j) => {
const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`,
const rwUl = Object.assign(document.createElement("ul"), {
innerText: u.name ? `${u.name} - ${u.match_count}` : `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`,
className: (j === currentUOp && i == currentKernel) ? "active" : "" })
if (j === currentUOp) {
requestAnimationFrame(() => rwUl.scrollIntoView({ behavior: "auto", block: "nearest" }));

View File

@@ -30,12 +30,13 @@ class GraphRewriteMetadata(TypedDict):
match_count: int # total match count in this context
code_line: str # source code calling graph_rewrite
kernel_code: str|None # optionally render the final kernel code
name: str|None # optional name of the rewrite
@functools.lru_cache(None)
def _prg(k:Kernel): return k.to_program().src
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None}
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None, "name":v.name}
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]