start changes from big graph (#6993)

* start changes from big graph [pr]

* space

* still capture ctx
This commit is contained in:
qazal
2024-10-11 11:13:46 +03:00
committed by GitHub
parent e7a0ffe46a
commit 7988547df2

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.view import View, strides_for_shape
# creation can recurse a lot
# Raise the interpreter's recursion limit to 10000 for that reason.
sys.setrecursionlimit(10000)
# NOTE(review): presumably a per-device cap on how many buffers one kernel may use; only "METAL" is listed -- confirm at the use site.
# NOTE(review): the two BUF_LIMIT lines differ only in spacing and look like the old/new sides of a diff shown without +/- markers.
BUF_LIMIT = {"METAL": 32}
BUF_LIMIT = {"METAL":32}
# Lookup from the MetaOps members COPY, EMPTY and VIEW to the UOps members COPY, EMPTY and BUFFER_VIEW.
METAOPS = {MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW}
# *** ScheduleItem return type ***
@@ -45,10 +45,6 @@ class LBScheduleItem:
# Slice of self.bufs after its leading entries: when the AST root op is UOps.SINK, one entry is skipped
# per element of self.ast.src; for any other root op only the first entry is skipped.
# NOTE(review): presumably the skipped leading entries are the output buffers -- confirm against the class fields.
@property
def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
# Frozen dataclass with a single field; it is the type of the `ctx` parameter in one of the
# full_ast_rewrite signatures shown further down.
@dataclass(frozen=True)
class ScheduleItemContext:
# tuple of ints; the caller builds it from buf_uops[...].arg[0] and the BUFFER pattern calls .index() on it
bufs: Tuple[int, ...]
# *** UOp with SWIZZLE (movementops) rewriting to UOp we can index ***
# ** helpers for doing movementops on uops
@@ -124,20 +120,20 @@ reduceop_fusor = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# Single-pattern matcher: every UOps.BUFFER node `x` is replaced by a UOps.DEFINE_GLOBAL of the same dtype
# with no sources, whose arg is the position of x.arg[0] inside the rewrite context.
# NOTE(review): both forms of the assignment appear here (diff shown without +/- markers): the first looks the
# position up in ctx.bufs, the second calls .index() on ctx directly.
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.bufs.index(x.arg[0])))])
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.index(x.arg[0])))])
# Module-level list of (base_sink, context, rewritten sink) triples. full_ast_rewrite appends one each time it
# performs a rewrite, and save_process_replay iterates over them.
# NOTE(review): both annotations appear here (diff shown without +/- markers): the first types the middle
# element as ScheduleItemContext, the second as Tuple[int, ...].
PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = []
PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, Tuple[int, ...], UOp]] = []
# Only when getenv("RUN_PROCESS_REPLAY") is truthy: register an exit hook that hands the captured rewrites to diskcache_put.
if getenv("RUN_PROCESS_REPLAY"):
@atexit.register
def save_process_replay():
# one diskcache_put call per captured triple, passing "schedule_process_replay", str(base_sink.key) and the triple itself
for base_sink,ctx,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret))
# Runs two graph_rewrite passes over base_sink and returns the resulting UOp.
# The second argument (named ctx in the first signature, bufs in the second) is forwarded to graph_rewrite
# as the context for the enumerate_bufs pass.
# NOTE(review): this span interleaves what look like the old and new versions of the function (diff shown
# without +/- markers): the ctx-based lines come first, the bufs-based lines second.
@track_rewrites
def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp:
# rewriting disabled: hand the input back untouched and record nothing
if not AST_REWRITE: return base_sink
# pass 1: reduceop_fusor patterns, no context argument
sink = graph_rewrite(base_sink, reduceop_fusor)
# pass 2: enumerate_bufs patterns with the context, then record the (input, context, output) triple
ret = graph_rewrite(sink, enumerate_bufs, ctx)
PROCESS_REPLAY_CAPTURE.append((base_sink, ctx, ret))
ret = graph_rewrite(sink, enumerate_bufs, bufs)
PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret))
return ret
# *** List[LazyBuffer] lowering to ScheduleItem ***
@@ -201,7 +197,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tupl
output_st, vv = output_st.simplify().unbind()
var_vals.update(vv)
ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs)))
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs))
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
# *** DAG creation: decide which LazyBuffers should realize ***