mirror of https://github.com/tinygrad/tinygrad.git
start changes from big graph (#6993)
* start changes from big graph [pr]
* space
* still capture ctx
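The core of this patch, per the hunks below, is dropping the one-field ScheduleItemContext wrapper and passing the bare Tuple[int, ...] of buffer ids straight into full_ast_rewrite. A minimal sketch of why the wrapper is redundant (only the names mirror the diff; the buffer ids are made up):

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ScheduleItemContext:  # the one-field wrapper this commit removes
  bufs: Tuple[int, ...]

bufs = (7, 3, 9)  # hypothetical buffer ids, outputs first then inputs
# before: the rewrite context was the wrapper, so patterns went through ctx.bufs.index(...)
assert ScheduleItemContext(bufs=bufs).bufs.index(9) == 2
# after: the tuple itself is the ctx, so patterns can call ctx.index(...) directly
assert bufs.index(9) == 2

Since the only thing the context is used for is an index() lookup, the tuple can serve as the rewrite ctx on its own.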
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View, strides_for_shape
 # creation can recurse a lot
 sys.setrecursionlimit(10000)

-BUF_LIMIT = {"METAL": 32}
+BUF_LIMIT = {"METAL":32}
+METAOPS = {MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW}

 # *** ScheduleItem return type ***
@@ -45,10 +45,6 @@ class LBScheduleItem:
   @property
   def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]

-@dataclass(frozen=True)
-class ScheduleItemContext:
-  bufs: Tuple[int, ...]
-
 # *** UOp with SWIZZLE (movementops) rewriting to UOp we can index ***

 # ** helpers for doing movementops on uops
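The LBScheduleItem.inputs property shown as context above relies on bufs being ordered outputs-first. A toy stand-in of that slicing rule (ToyAst and ToyLBScheduleItem are hypothetical classes, not tinygrad's):

from typing import NamedTuple, Tuple

class ToyAst(NamedTuple):
  op: str
  src: Tuple[object, ...] = ()

class ToyLBScheduleItem(NamedTuple):
  ast: ToyAst
  bufs: Tuple[str, ...]  # ordered outputs first, then inputs
  @property
  def inputs(self) -> Tuple[str, ...]:
    # with a SINK ast the number of outputs equals the number of SINK sources, otherwise 1
    return self.bufs[len(self.ast.src):] if self.ast.op == "SINK" else self.bufs[1:]

si = ToyLBScheduleItem(ToyAst("SINK", ("store0", "store1")), ("out0", "out1", "in0", "in1"))
assert si.inputs == ("in0", "in1")  # two SINK stores -> first two bufs are outputs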
@@ -124,20 +120,20 @@ reduceop_fusor = PatternMatcher([
   (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])

-enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.bufs.index(x.arg[0])))])
+enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.index(x.arg[0])))])

-PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = []
+PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, Tuple[int, ...], UOp]] = []
 if getenv("RUN_PROCESS_REPLAY"):
   @atexit.register
   def save_process_replay():
     for base_sink,ctx,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret))

 @track_rewrites
-def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
+def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp:
   if not AST_REWRITE: return base_sink
   sink = graph_rewrite(base_sink, reduceop_fusor)
-  ret = graph_rewrite(sink, enumerate_bufs, ctx)
-  PROCESS_REPLAY_CAPTURE.append((base_sink, ctx, ret))
+  ret = graph_rewrite(sink, enumerate_bufs, bufs)
+  PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret))
   return ret

 # *** List[LazyBuffer] lowering to ScheduleItem ***
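A hedged stand-in for what the ctx-threading enumerate_bufs rewrite does after this change: with the bufs tuple passed directly as ctx, each BUFFER node becomes a DEFINE_GLOBAL whose arg is the position of its buffer id in that tuple. The Node class and toy_graph_rewrite below are toys, not tinygrad's UOp/PatternMatcher/graph_rewrite:

from typing import NamedTuple, Tuple, Callable, Optional

class Node(NamedTuple):  # minimal stand-in for a UOp
  op: str
  arg: object = None
  src: Tuple["Node", ...] = ()

def toy_graph_rewrite(n: Node, rule: Callable[[Tuple[int, ...], Node], Optional[Node]],
                      ctx: Tuple[int, ...]) -> Node:
  # rewrite children bottom-up, then give the rule a chance on the rebuilt node
  n = n._replace(src=tuple(toy_graph_rewrite(s, rule, ctx) for s in n.src))
  return rule(ctx, n) or n

# in the spirit of enumerate_bufs: BUFFER(id) -> DEFINE_GLOBAL(ctx.index(id))
def number_bufs(ctx: Tuple[int, ...], x: Node) -> Optional[Node]:
  return Node("DEFINE_GLOBAL", ctx.index(x.arg)) if x.op == "BUFFER" else None

bufs = (7, 3, 9)  # hypothetical buffer ids, outputs first then inputs
sink = Node("STORE", src=(Node("BUFFER", 9), Node("BUFFER", 3)))
print(toy_graph_rewrite(sink, number_bufs, bufs))
# -> STORE over DEFINE_GLOBAL(2) and DEFINE_GLOBAL(1): positions of ids 9 and 3 in bufs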
@@ -201,7 +197,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tupl
     output_st, vv = output_st.simplify().unbind()
     var_vals.update(vv)
     ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
-  sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs)))
+  sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs))
   return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals

 # *** DAG creation: decide which LazyBuffers should realize ***
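The RUN_PROCESS_REPLAY block in the middle hunk follows a capture-and-flush pattern: every rewrite appends (input sink, ctx, output) to a module-level list, and an atexit hook persists the list once at interpreter exit. A generic stdlib sketch of the same pattern (pickle to a temp file stands in for tinygrad's diskcache_put; CAPTURE and rewrite are hypothetical names):

import atexit, os, pickle, tempfile
from typing import Any, List, Tuple

CAPTURE: List[Tuple[Any, Tuple[int, ...], Any]] = []

if os.getenv("RUN_PROCESS_REPLAY"):
  @atexit.register
  def _save_capture() -> None:
    # hypothetical sink: dump all captures to a temp file instead of a diskcache
    with open(os.path.join(tempfile.gettempdir(), "schedule_replay.pkl"), "wb") as f:
      pickle.dump(CAPTURE, f)

def rewrite(sink: Any, bufs: Tuple[int, ...]) -> Any:
  ret = sink  # stand-in for graph_rewrite(sink, enumerate_bufs, bufs)
  CAPTURE.append((sink, bufs, ret))
  return ret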