From 7988547df2b2c1ea999e996a13dbbb9caa22515e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 11 Oct 2024 11:13:46 +0300 Subject: [PATCH] start changes from big graph (#6993) * start changes from big graph [pr] * space * still capture ctx --- tinygrad/engine/schedule.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 996d90e839..99eb759a83 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -16,7 +16,7 @@ from tinygrad.shape.view import View, strides_for_shape # creation can recurse a lot sys.setrecursionlimit(10000) -BUF_LIMIT = {"METAL": 32} +BUF_LIMIT = {"METAL":32} METAOPS = {MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW} # *** ScheduleItem return type *** @@ -45,10 +45,6 @@ class LBScheduleItem: @property def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:] -@dataclass(frozen=True) -class ScheduleItemContext: - bufs: Tuple[int, ...] - # *** UOp with SWIZZLE (movementops) rewriting to UOp we can index *** # ** helpers for doing movementops on uops @@ -124,20 +120,20 @@ reduceop_fusor = PatternMatcher([ (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) -enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.bufs.index(x.arg[0])))]) +enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.index(x.arg[0])))]) -PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = [] +PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, Tuple[int, ...], UOp]] = [] if getenv("RUN_PROCESS_REPLAY"): @atexit.register def save_process_replay(): for base_sink,ctx,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret)) @track_rewrites -def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp: +def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp: if not AST_REWRITE: return base_sink sink = graph_rewrite(base_sink, reduceop_fusor) - ret = graph_rewrite(sink, enumerate_bufs, ctx) - PROCESS_REPLAY_CAPTURE.append((base_sink, ctx, ret)) + ret = graph_rewrite(sink, enumerate_bufs, bufs) + PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret)) return ret # *** List[LazyBuffer] lowering to ScheduleItem *** @@ -201,7 +197,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tupl output_st, vv = output_st.simplify().unbind() var_vals.update(vv) ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src))) - sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs))) + sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs)) return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals # *** DAG creation: decide which LazyBuffers should realize ***