mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
assign toposort with big graph, bfs [pr] (#7242)
* assign toposort with big graph, bfs [pr]
* cycle
* merge 2
* filter bufs
* delete inputs
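The core of this commit is running the kernel toposort as a breadth-first traversal (Kahn's algorithm) directly over ScheduleItems: dependency edges feed an in_degree counter, and items whose count drops to zero are popped from a queue. A minimal, self-contained sketch of that pattern, with generic names rather than tinygrad's API:

    from collections import defaultdict, deque

    def toposort_bfs(items, deps):
        # deps[i] lists the items that must run before i (hypothetical input shape)
        graph, in_degree = defaultdict(list), defaultdict(int)
        for it in items:
            for d in deps.get(it, ()):
                graph[d].append(it)   # edge: prerequisite -> dependent
                in_degree[it] += 1
        queue = deque(it for it in items if in_degree[it] == 0)
        order = []
        while queue:
            order.append(it := queue.popleft())
            for nxt in graph[it]:     # scheduling `it` unblocks its dependents
                in_degree[nxt] -= 1
                if in_degree[nxt] == 0: queue.append(nxt)
        if len(order) != len(items): raise RuntimeError("cycle in schedule graph")
        return order

    assert toposort_bfs(["a", "b", "c"], {"c": ["a", "b"], "b": ["a"]}) == ["a", "b", "c"]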
@@ -48,4 +48,4 @@ def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
   # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
   assigned = _internal_memory_planner([si.bufs for si in schedule],
                                       noopt_buffers={b for si in schedule if si.ast.op is not UOps.SINK for b in si.bufs})
-  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
+  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule]
@@ -25,6 +25,7 @@ class ScheduleItem:
   ast: UOp
   bufs: Tuple[Buffer, ...]
   metadata: Tuple[Metadata, ...]
+  assign_preloads: Tuple[UOp, ...]
   @property
   def outputs(self) -> Tuple[Buffer, ...]:
     """Read/write or write only buffers in the schedule."""
@@ -36,17 +37,6 @@ class ScheduleItem:
   @functools.cached_property
   def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,)
 
-@dataclass(frozen=True)
-class LBScheduleItem:
-  ast: UOp
-  bufs: Tuple[LazyBuffer, ...]
-  ubufs: Tuple[UOp, ...]
-  metadata: Tuple[Metadata, ...]
-  @property
-  def outputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[:len(self.ast.src)] if self.ast.op is UOps.SINK else self.bufs[0:1]
-  @property
-  def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
-
 # *** UOp with VIEW (movementops) rewriting to UOp we can index ***
 
 # ** helpers for doing movementops on uops
@@ -132,8 +122,10 @@ view_right = merge_views+PatternMatcher([
 @dataclass(frozen=True)
 class ScheduleItemContext:
   var_vals: Dict[Variable, int]
+  assigned: Set[UOp]
   sts: Set[ShapeTracker] = field(default_factory=set)
   bufs: List[UOp] = field(default_factory=list)
+  assign_preloads: List[UOp] = field(default_factory=list)
 
 def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
   if (st:=unwrap(x.st)) in ctx.sts: return None
@@ -145,11 +137,15 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
 def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
   ctx.bufs.append(x)
   return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
+append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf)])
 
+def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
+  if b in ctx.assigned: ctx.assign_preloads.append(b)
+  return x.replace(op=UOps.LOAD)
+
 to_si = PatternMatcher([
-  (UPat(UOps.BUFFER, name="x"), _append_buf),
   (UPat(UOps.VIEW, name="x"), _append_st_vars),
-  (UPat(UOps.PRELOAD, name="x"), lambda _,x: x.replace(op=UOps.LOAD)),
+  (UPat(UOps.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
   (UPat(UOps.CONTIGUOUS, src=(UPat.var("x"),)), lambda _,x: x),
   (UPat(UOps.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda _,x: x),
 ])
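For orientation: the new _append_preload rule still demotes a PRELOAD to an ordinary LOAD, but first records the buffer in ctx.assign_preloads when it is an assign target. A rough sketch of that bookkeeping with hypothetical stand-in types (plain strings instead of UOps, no PatternMatcher):

    from dataclasses import dataclass, field
    from typing import List, Set

    @dataclass
    class Ctx:
        assigned: Set[str]                        # buffers targeted by an ASSIGN
        assign_preloads: List[str] = field(default_factory=list)

    def append_preload(ctx: Ctx, buf: str) -> str:
        # remember reads of assigned buffers so the scheduler can order this
        # kernel before the kernel that overwrites buf
        if buf in ctx.assigned: ctx.assign_preloads.append(buf)
        return "LOAD"                             # the PRELOAD becomes a normal LOAD

    ctx = Ctx(assigned={"b1"})
    append_preload(ctx, "b1")                     # recorded
    append_preload(ctx, "b2")                     # not an assign target, skipped
    assert ctx.assign_preloads == ["b1"]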
@@ -157,8 +153,8 @@ to_si = PatternMatcher([
 PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = []
 def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
   sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
-  ret = graph_rewrite(sink, to_si, ctx)
-  PROCESS_REPLAY_CAPTURE.append((base_sink, ScheduleItemContext(ctx.var_vals), ret))
+  ret = graph_rewrite(graph_rewrite(sink, to_si, ctx), append_bufs, ctx)
+  PROCESS_REPLAY_CAPTURE.append((base_sink, ScheduleItemContext(ctx.var_vals, ctx.assigned), ret))
   return ret
 
 if getenv("RUN_PROCESS_REPLAY"):
@@ -168,18 +164,16 @@ if getenv("RUN_PROCESS_REPLAY"):
 
 # *** List[LazyBuffer] lowering to ScheduleItem ***
 
-def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], metadata:Dict[UOp, Metadata],
-           cache:Dict[LazyBuffer, UOp]) -> UOp:
+def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], metadata:Dict[UOp, Metadata], cache:Dict[LazyBuffer, UOp]) -> UOp:
   if (r:=cache.get(buf)) is not None: return r
   if buf is not buf.base:
-    cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, metadata, cache).view(buf.st)
+    cache[buf] = ret = to_uop(buf.base, outputs, buf_uops, metadata, cache).view(buf.st)
     return ret
   if buf.op is MetaOps.CONST: return buf_uops[buf.buffer]
   dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
   if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs:
-    if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf)
     return UOp(UOps.PRELOAD if buf.is_realized() else UOps.LOAD, dtype, (ubuf, buf.st.to_uop()))
-  src = tuple(to_uop(x, outputs, inputs, buf_uops, metadata, cache) for x in buf.srcs)
+  src = tuple(to_uop(x, outputs, buf_uops, metadata, cache) for x in buf.srcs)
   if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
   elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
   elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]), buf.arg)
@@ -191,26 +185,25 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu
   if buf.metadata is not None: metadata[ret] = buf.metadata
   return ret
 
-def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], uop_bufs:Dict[UOp, Buffer],
-                      var_vals:Dict[Variable, int]) -> LBScheduleItem:
+def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], uop_bufs:Dict[UOp, Buffer], assigned:Set[UOp],
+                      var_vals:Dict[Variable, int]) -> ScheduleItem:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   cache: Dict[LazyBuffer, UOp] = {}
-  inputs: List[LazyBuffer] = []
   metadata: Dict[UOp, Metadata] = {}
   sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
-                                            to_uop(out, outs, inputs, buf_uops, metadata, cache)) for out in outs))
+                                            to_uop(out, outs, buf_uops, metadata, cache)) for out in outs))
   # assert cyclic dependency
-  for b,reads in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD, UOps.LOAD}), key=lambda x:x.src[0]):
+  for b,reads in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD, UOps.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]):
     if not all_same([x.op for x in reads]):
       raise RuntimeError(f"cycle detected in kernel.\nhelp: consider using .contiguous() to load the pre-assign version of {uop_bufs[b]}.")
-  sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals))
+  sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals, assigned))
   # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
   if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
     if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
       raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-  return LBScheduleItem(sink, tuple(outs+inputs), tuple(ctx.bufs), tuple(dedup(metadata.values())))
+  return ScheduleItem(sink, tuple(b for u in ctx.bufs if (b:=uop_bufs[u]).size != 0), tuple(dedup(metadata.values())), tuple(ctx.assign_preloads))
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 
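The cycle check above now only inspects reads of assigned buffers: if one kernel reads the same assigned buffer through both PRELOAD (pre-assign value) and LOAD (post-assign value), no valid ordering exists. A small illustrative version of the groupby/all_same test, using hypothetical (buffer, op) pairs in place of UOps:

    import itertools

    def check_assign_cycle(reads, assigned):
        # reads: (buffer, op) pairs with op in {"PRELOAD", "LOAD"}; mixing both
        # ops on a single assigned buffer inside one kernel is a cycle
        for buf, grp in itertools.groupby(sorted(reads), key=lambda r: r[0]):
            if buf in assigned and len({op for _, op in grp}) > 1:
                raise RuntimeError(f"cycle detected in kernel for {buf}")

    check_assign_cycle([("b1", "LOAD"), ("b2", "PRELOAD")], {"b1"})    # ok
    # check_assign_cycle([("b1", "LOAD"), ("b1", "PRELOAD")], {"b1"})  # would raise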
@@ -372,7 +365,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
   buf_uops: Dict[Buffer, UOp] = {}
   uop_bufs: Dict[UOp, Buffer] = {}
   var_vals: Dict[Variable, int] = {}
-  assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
+  assigned: Set[UOp] = set()
   lazybufs_to_realize: Dict[Buffer, LazyBuffer] = {}
   for buf in realizes:
     if buf.realized is None and buf.op is not MetaOps.CONST:
@@ -399,24 +392,24 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
     buf_uops[buf.buffer] = uop
     uop_bufs[uop] = buf.buffer
     if buf.realized is None:
-      if buf.op is MetaOps.ASSIGN: assign_targets[buf.srcs[0]] = buf
+      if buf.op is MetaOps.ASSIGN: assigned.add(buf_uops[buf.buffer])
       if buf.op is not MetaOps.CONST: output_groups[reduce_for_op.get(buf, buf)].append(buf_uops[buf.buffer])
 
   # preschedule all buffers in realizes
-  prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, uop_bufs,
+  prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, uop_bufs, assigned,
                   var_vals) for outs in output_groups.values()]
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
 
-  graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
-  in_degree: DefaultDict[LBScheduleItem, int] = defaultdict(int)
+  graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
+  in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int)
   for lsi in prescheduled:
     # realize outputs before a parent is assigned to
-    parents_assigns = dedup(schedule_targets[assign_targets[x]] for x in lsi.inputs if x in assign_targets)
+    parents_assigns = dedup(xsi for x in lsi.assign_preloads if (xsi:=schedule_targets.get(uop_bufs[x])) and xsi is not lsi)
     for assign in parents_assigns:
       graph[lsi].append(assign)
       in_degree[assign] += 1
     # realize outputs after all parents are realized
-    scheduled_parents = dedup(xsi for x in lsi.inputs if (xsi:=schedule_targets.get(x)) is not None)
+    scheduled_parents = dedup(xsi for x in lsi.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
     for x in scheduled_parents:
       graph[x].append(lsi)
       in_degree[lsi] += 1
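Two details in the edge building above are worth noting: assign_preloads now drive the "run before the parent's assign" edges directly from the lowered AST, and the xsi is not lsi guard drops self-references, since a kernel given an edge to itself would hold a nonzero in_degree forever and the BFS below could never schedule it. A toy demonstration of the stuck self-edge:

    from collections import defaultdict, deque

    graph, in_degree = defaultdict(list), defaultdict(int)
    graph["k"].append("k"); in_degree["k"] += 1            # hypothetical self-edge
    queue = deque(k for k in ["k"] if in_degree[k] == 0)
    assert len(queue) == 0                                 # "k" never becomes ready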
@@ -424,13 +417,12 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
   queue = deque(lsi for lsi in prescheduled if in_degree[lsi] == 0)
   schedule: List[ScheduleItem] = []
   while queue:
-    lsi = queue.popleft()
-    schedule.append(si:=ScheduleItem(lsi.ast, tuple(b for u in lsi.ubufs if (b:=uop_bufs[u]).size != 0), lsi.metadata))
+    schedule.append(si:=queue.popleft())
     for b in si.outputs: del lazybufs_to_realize[b].srcs # can only schedule once
     if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
       if DEBUG >= 3: print(si)
       raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
-    for x in graph[lsi]:
+    for x in graph[si]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
 