mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 08:48:15 -05:00
late uop_bufs [pr] (#7438)
This commit is contained in:
@@ -42,7 +42,6 @@ class ScheduleItem:
|
||||
class ScheduleContext:
|
||||
realizes: Dict[Buffer, LazyBuffer]
|
||||
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict)
|
||||
uop_bufs: Dict[UOp, Buffer] = field(default_factory=dict)
|
||||
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict)
|
||||
var_vals: Dict[Variable, int] = field(default_factory=dict)
|
||||
|
||||
@@ -59,7 +58,6 @@ def to_uop(buf:LazyBuffer, realizes:Dict[UOp, UOp], ctx:ScheduleContext, cache:D
|
||||
# everything else has BUFFER
|
||||
if (b:=buf.buffer) not in ctx.buf_uops:
|
||||
ctx.buf_uops[b] = ubuf = UOp(UOps.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype)))
|
||||
ctx.uop_bufs[ubuf] = b
|
||||
if b in ctx.realizes: realizes[ubuf] = ubuf
|
||||
else: ubuf = ctx.buf_uops[b]
|
||||
# if the buffer is already realized we just load it
|
||||
@@ -251,14 +249,15 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
small_graphs.append(full_ast_rewrite(sink, ctx.var_vals, assigned))
|
||||
|
||||
# do BFS
|
||||
prescheduled = [ScheduleItem(u, tuple(b for u in c.bufs if (b:=ctx.uop_bufs[u]).size != 0),
|
||||
bufs = list(ctx.buf_uops)
|
||||
prescheduled = [ScheduleItem(u, tuple(b for u in c.bufs if (b:=bufs[u.arg[0]]).size != 0),
|
||||
tuple(m), tuple(c.assign_preloads)) for (u,c),m in zip(small_graphs, metadata)]
|
||||
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
||||
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
|
||||
in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int)
|
||||
for si in prescheduled:
|
||||
# realize outputs before a parent is assigned to
|
||||
parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(ctx.uop_bufs[x])) and xsi is not si)
|
||||
parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(bufs[x.arg[0]])) and xsi is not si)
|
||||
for assign in parents_assigns:
|
||||
graph[si].append(assign)
|
||||
in_degree[assign] += 1
|
||||
|
||||
Reference in New Issue
Block a user