late uop_bufs [pr] (#7438)

This commit is contained in:
qazal
2024-10-31 11:30:32 +02:00
committed by GitHub
parent 7916d1f6ab
commit a0bd385448

View File

@@ -42,7 +42,6 @@ class ScheduleItem:
class ScheduleContext:
realizes: Dict[Buffer, LazyBuffer]
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict)
uop_bufs: Dict[UOp, Buffer] = field(default_factory=dict)
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict)
var_vals: Dict[Variable, int] = field(default_factory=dict)
@@ -59,7 +58,6 @@ def to_uop(buf:LazyBuffer, realizes:Dict[UOp, UOp], ctx:ScheduleContext, cache:D
# everything else has BUFFER
if (b:=buf.buffer) not in ctx.buf_uops:
ctx.buf_uops[b] = ubuf = UOp(UOps.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype)))
ctx.uop_bufs[ubuf] = b
if b in ctx.realizes: realizes[ubuf] = ubuf
else: ubuf = ctx.buf_uops[b]
# if the buffer is already realized we just load it
@@ -251,14 +249,15 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
small_graphs.append(full_ast_rewrite(sink, ctx.var_vals, assigned))
# do BFS
prescheduled = [ScheduleItem(u, tuple(b for u in c.bufs if (b:=ctx.uop_bufs[u]).size != 0),
bufs = list(ctx.buf_uops)
prescheduled = [ScheduleItem(u, tuple(b for u in c.bufs if (b:=bufs[u.arg[0]]).size != 0),
tuple(m), tuple(c.assign_preloads)) for (u,c),m in zip(small_graphs, metadata)]
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int)
for si in prescheduled:
# realize outputs before a parent is assigned to
parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(ctx.uop_bufs[x])) and xsi is not si)
parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(bufs[x.arg[0]])) and xsi is not si)
for assign in parents_assigns:
graph[si].append(assign)
in_degree[assign] += 1