mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add support for named rewrites [pr] (#9152)
This commit is contained in:
@@ -407,6 +407,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
# remove_movement_ops + sym
|
||||
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
|
||||
|
||||
# display the cleaned up tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
|
||||
|
||||
# do_realize + group_realizes
|
||||
buffer_map: dict[UOp, UOp] = {}
|
||||
sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
|
||||
@@ -452,8 +455,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
raise RuntimeError(f"cycle detected in graph, kernel must either depend on ASSIGN or BUFFER for {k}")
|
||||
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
|
||||
if assign_rep: sched_sink = sched_sink.substitute(assign_rep)
|
||||
|
||||
# display the final graph
|
||||
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
|
||||
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
|
||||
|
||||
# final toposort (bfs)
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
|
||||
@@ -853,6 +853,7 @@ class TrackedGraphRewrite:
|
||||
sink: UOp # the sink input to graph_rewrite
|
||||
bottom_up: bool
|
||||
matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
|
||||
name: Optional[str] = None
|
||||
tracked_keys:list[Any] = []
|
||||
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
|
||||
_name_cnt:dict[str, int] = {}
|
||||
@@ -935,14 +936,14 @@ class RewriteContext:
|
||||
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
|
||||
return ret
|
||||
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
|
||||
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
|
||||
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
|
||||
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
|
||||
rewrite_ctx = RewriteContext(pm, ctx)
|
||||
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
|
||||
|
||||
|
||||
@@ -327,7 +327,8 @@
|
||||
const p = Object.assign(document.createElement("p"), { id: `kernel-${key}`, innerHTML: coloredToHTML(key), style: "cursor: pointer;"});
|
||||
kernelUl.appendChild(p)
|
||||
items.forEach((u, j) => {
|
||||
const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`,
|
||||
const rwUl = Object.assign(document.createElement("ul"), {
|
||||
innerText: u.name ? `${u.name} - ${u.match_count}` : `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`,
|
||||
className: (j === currentUOp && i == currentKernel) ? "active" : "" })
|
||||
if (j === currentUOp) {
|
||||
requestAnimationFrame(() => rwUl.scrollIntoView({ behavior: "auto", block: "nearest" }));
|
||||
|
||||
@@ -30,12 +30,13 @@ class GraphRewriteMetadata(TypedDict):
|
||||
match_count: int # total match count in this context
|
||||
code_line: str # source code calling graph_rewrite
|
||||
kernel_code: str|None # optionally render the final kernel code
|
||||
name: str|None # optional name of the rewrite
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _prg(k:Kernel): return k.to_program().src
|
||||
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
|
||||
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
|
||||
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None}
|
||||
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None, "name":v.name}
|
||||
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
|
||||
return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user