From a0bd38544871a33e24f16fe538a3cdc2a99d33d0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:30:32 +0200 Subject: [PATCH] late uop_bufs [pr] (#7438) --- tinygrad/engine/schedule.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index f991a80ca9..10afc2c08f 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -42,7 +42,6 @@ class ScheduleItem: class ScheduleContext: realizes: Dict[Buffer, LazyBuffer] buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) - uop_bufs: Dict[UOp, Buffer] = field(default_factory=dict) ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) var_vals: Dict[Variable, int] = field(default_factory=dict) @@ -59,7 +58,6 @@ def to_uop(buf:LazyBuffer, realizes:Dict[UOp, UOp], ctx:ScheduleContext, cache:D # everything else has BUFFER if (b:=buf.buffer) not in ctx.buf_uops: ctx.buf_uops[b] = ubuf = UOp(UOps.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype))) - ctx.uop_bufs[ubuf] = b if b in ctx.realizes: realizes[ubuf] = ubuf else: ubuf = ctx.buf_uops[b] # if the buffer is already realized we just load it @@ -251,14 +249,15 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] small_graphs.append(full_ast_rewrite(sink, ctx.var_vals, assigned)) # do BFS - prescheduled = [ScheduleItem(u, tuple(b for u in c.bufs if (b:=ctx.uop_bufs[u]).size != 0), + bufs = list(ctx.buf_uops) + prescheduled = [ScheduleItem(u, tuple(b for u in c.bufs if (b:=bufs[u.arg[0]]).size != 0), tuple(m), tuple(c.assign_preloads)) for (u,c),m in zip(small_graphs, metadata)] schedule_targets = {out:si for si in prescheduled for out in si.outputs} graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list) in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int) for si in prescheduled: # realize outputs before a parent is assigned to - parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(ctx.uop_bufs[x])) and xsi is not si) + parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(bufs[x.arg[0]])) and xsi is not si) for assign in parents_assigns: graph[si].append(assign) in_degree[assign] += 1