mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cachable small graph rewrite (#7371)
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
# compare kernels created by HEAD against master
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools
|
||||
from typing import Callable, List, Tuple, Union, cast
|
||||
from tinygrad.engine.schedule import ScheduleItemContext, full_ast_rewrite
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
||||
from tinygrad.engine.schedule import full_ast_rewrite
|
||||
from tinygrad.codegen.kernel import Kernel, Opt
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.ops import UOp
|
||||
@@ -28,7 +28,7 @@ if REF == "master": SKIP_PROCESS_REPLAY = True
|
||||
|
||||
# *** recreators
|
||||
|
||||
def recreate_sched(sink:UOp, ctx:ScheduleItemContext) -> UOp: return full_ast_rewrite(sink, ctx)
|
||||
def recreate_sched(*args) -> UOp: return full_ast_rewrite(*args[0], ubuf_metadata={})[0]
|
||||
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, _) -> str:
|
||||
k = Kernel(ast, opts=opts)
|
||||
for opt in applied_opts: k.apply_opt(opt)
|
||||
|
||||
@@ -201,7 +201,9 @@ multioutput = PatternMatcher([
|
||||
(UPat.load(UPat.var("b"), UPat()), lambda stores,b: stores.get(b)),
|
||||
])
|
||||
|
||||
def full_ast_rewrite(pre:UOp, ctx:ScheduleItemContext) -> UOp:
|
||||
def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp], ubuf_metadata:Dict[UOp, Metadata]) -> Tuple[UOp, ScheduleItemContext]:
|
||||
metadata = {mx:None for x in pre.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ubuf_metadata.get(x.src[0]))}
|
||||
ctx = ScheduleItemContext(var_vals, assigned, metadata=metadata)
|
||||
# fuse and fold store -> loads
|
||||
sink = graph_rewrite(pre, lazy)
|
||||
# fuse multi output
|
||||
@@ -220,14 +222,14 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleItemContext) -> UOp:
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append((pre, ScheduleItemContext(ctx.var_vals, ctx.assigned), sink))
|
||||
return sink
|
||||
if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, var_vals, assigned), sink))
|
||||
return sink, ctx
|
||||
|
||||
PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = []
|
||||
PROCESS_REPLAY_CAPTURE: List[Tuple[Tuple, UOp]] = []
|
||||
if getenv("RUN_PROCESS_REPLAY"):
|
||||
@atexit.register
|
||||
def save_process_replay():
|
||||
for base_sink,ctx,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, {}, ret))
|
||||
for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x[0].key), (x, {}, ret))
|
||||
|
||||
# **** Schedule creation and BFS toposort
|
||||
|
||||
@@ -240,11 +242,8 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
for stores in store_groups:
|
||||
outs = [lazybufs_to_realize[b] for b in stores]
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
to_store = tuple(to_uop(out, outs, ctx, cache) for out in outs)
|
||||
sink = UOp(UOps.SINK, src=tuple(UOp.store(ctx.buf_uops[x.buffer], ShapeTracker.from_shape(x.shape).to_uop(), u) for x,u in zip(outs,to_store)))
|
||||
metadata = {mx:None for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))}
|
||||
si_ctx = ScheduleItemContext(ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}, metadata=metadata)
|
||||
small_graphs.append((full_ast_rewrite(sink, si_ctx), si_ctx))
|
||||
small_graphs.append(full_ast_rewrite(UOp.sink(*(to_uop(out, outs, ctx, cache).src[2] for out in outs)),
|
||||
ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}, ctx.ubuf_metadata))
|
||||
|
||||
# do BFS
|
||||
prescheduled = [ScheduleItem(u, tuple(b for u in c.bufs if (b:=ctx.uop_bufs[u]).size != 0),
|
||||
|
||||
Reference in New Issue
Block a user