assign toposort with big graph, bfs [pr] (#7242)

* assign toposort with big graph, bfs [pr]

* cycle

* merge 2

* filter bufs

* delete inputs
This commit is contained in:
qazal
2024-10-24 13:09:01 +03:00
committed by GitHub
parent 4d081eb560
commit fa5dc7857a
2 changed files with 30 additions and 38 deletions

View File

@@ -48,4 +48,4 @@ def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([si.bufs for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not UOps.SINK for b in si.bufs})
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule]

View File

@@ -25,6 +25,7 @@ class ScheduleItem:
ast: UOp
bufs: Tuple[Buffer, ...]
metadata: Tuple[Metadata, ...]
assign_preloads: Tuple[UOp, ...]
@property
def outputs(self) -> Tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
@@ -36,17 +37,6 @@ class ScheduleItem:
@functools.cached_property
def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,)
@dataclass(frozen=True)
class LBScheduleItem:
ast: UOp
bufs: Tuple[LazyBuffer, ...]
ubufs: Tuple[UOp, ...]
metadata: Tuple[Metadata, ...]
@property
def outputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[:len(self.ast.src)] if self.ast.op is UOps.SINK else self.bufs[0:1]
@property
def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
# *** UOp with VIEW (movementops) rewriting to UOp we can index ***
# ** helpers for doing movementops on uops
@@ -132,8 +122,10 @@ view_right = merge_views+PatternMatcher([
@dataclass(frozen=True)
class ScheduleItemContext:
var_vals: Dict[Variable, int]
assigned: Set[UOp]
sts: Set[ShapeTracker] = field(default_factory=set)
bufs: List[UOp] = field(default_factory=list)
assign_preloads: List[UOp] = field(default_factory=list)
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
if (st:=unwrap(x.st)) in ctx.sts: return None
@@ -145,11 +137,15 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
ctx.bufs.append(x)
return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf)])
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
if b in ctx.assigned: ctx.assign_preloads.append(b)
return x.replace(op=UOps.LOAD)
to_si = PatternMatcher([
(UPat(UOps.BUFFER, name="x"), _append_buf),
(UPat(UOps.VIEW, name="x"), _append_st_vars),
(UPat(UOps.PRELOAD, name="x"), lambda _,x: x.replace(op=UOps.LOAD)),
(UPat(UOps.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
(UPat(UOps.CONTIGUOUS, src=(UPat.var("x"),)), lambda _,x: x),
(UPat(UOps.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda _,x: x),
])
@@ -157,8 +153,8 @@ to_si = PatternMatcher([
PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = []
def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
ret = graph_rewrite(sink, to_si, ctx)
PROCESS_REPLAY_CAPTURE.append((base_sink, ScheduleItemContext(ctx.var_vals), ret))
ret = graph_rewrite(graph_rewrite(sink, to_si, ctx), append_bufs, ctx)
PROCESS_REPLAY_CAPTURE.append((base_sink, ScheduleItemContext(ctx.var_vals, ctx.assigned), ret))
return ret
if getenv("RUN_PROCESS_REPLAY"):
@@ -168,18 +164,16 @@ if getenv("RUN_PROCESS_REPLAY"):
# *** List[LazyBuffer] lowering to ScheduleItem ***
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], metadata:Dict[UOp, Metadata],
cache:Dict[LazyBuffer, UOp]) -> UOp:
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], metadata:Dict[UOp, Metadata], cache:Dict[LazyBuffer, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, metadata, cache).view(buf.st)
cache[buf] = ret = to_uop(buf.base, outputs, buf_uops, metadata, cache).view(buf.st)
return ret
if buf.op is MetaOps.CONST: return buf_uops[buf.buffer]
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs:
if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf)
return UOp(UOps.PRELOAD if buf.is_realized() else UOps.LOAD, dtype, (ubuf, buf.st.to_uop()))
src = tuple(to_uop(x, outputs, inputs, buf_uops, metadata, cache) for x in buf.srcs)
src = tuple(to_uop(x, outputs, buf_uops, metadata, cache) for x in buf.srcs)
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]), buf.arg)
@@ -191,26 +185,25 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu
if buf.metadata is not None: metadata[ret] = buf.metadata
return ret
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], uop_bufs:Dict[UOp, Buffer],
var_vals:Dict[Variable, int]) -> LBScheduleItem:
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], uop_bufs:Dict[UOp, Buffer], assigned:Set[UOp],
var_vals:Dict[Variable, int]) -> ScheduleItem:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
cache: Dict[LazyBuffer, UOp] = {}
inputs: List[LazyBuffer] = []
metadata: Dict[UOp, Metadata] = {}
sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
to_uop(out, outs, inputs, buf_uops, metadata, cache)) for out in outs))
to_uop(out, outs, buf_uops, metadata, cache)) for out in outs))
# assert cyclic dependency
for b,reads in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD, UOps.LOAD}), key=lambda x:x.src[0]):
for b,reads in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD, UOps.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]):
if not all_same([x.op for x in reads]):
raise RuntimeError(f"cycle detected in kernel.\nhelp: consider using .contiguous() to load the pre-assign version of {uop_bufs[b]}.")
sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals))
sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals, assigned))
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
return LBScheduleItem(sink, tuple(outs+inputs), tuple(ctx.bufs), tuple(dedup(metadata.values())))
return ScheduleItem(sink, tuple(b for u in ctx.bufs if (b:=uop_bufs[u]).size != 0), tuple(dedup(metadata.values())), tuple(ctx.assign_preloads))
# *** DAG creation: decide which LazyBuffers should realize ***
@@ -372,7 +365,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
buf_uops: Dict[Buffer, UOp] = {}
uop_bufs: Dict[UOp, Buffer] = {}
var_vals: Dict[Variable, int] = {}
assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
assigned: Set[UOp] = set()
lazybufs_to_realize: Dict[Buffer, LazyBuffer] = {}
for buf in realizes:
if buf.realized is None and buf.op is not MetaOps.CONST:
@@ -399,24 +392,24 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
buf_uops[buf.buffer] = uop
uop_bufs[uop] = buf.buffer
if buf.realized is None:
if buf.op is MetaOps.ASSIGN: assign_targets[buf.srcs[0]] = buf
if buf.op is MetaOps.ASSIGN: assigned.add(buf_uops[buf.buffer])
if buf.op is not MetaOps.CONST:output_groups[reduce_for_op.get(buf, buf)].append(buf_uops[buf.buffer])
# preschedule all buffers in realizes
prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, uop_bufs,
prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, uop_bufs, assigned,
var_vals) for outs in output_groups.values()]
schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
in_degree: DefaultDict[LBScheduleItem, int] = defaultdict(int)
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int)
for lsi in prescheduled:
# realize outputs before a parent is assigned to
parents_assigns = dedup(schedule_targets[assign_targets[x]] for x in lsi.inputs if x in assign_targets)
parents_assigns = dedup(xsi for x in lsi.assign_preloads if (xsi:=schedule_targets.get(uop_bufs[x])) and xsi is not lsi)
for assign in parents_assigns:
graph[lsi].append(assign)
in_degree[assign] += 1
# realize outputs after all parents are realized
scheduled_parents = dedup(xsi for x in lsi.inputs if (xsi:=schedule_targets.get(x)) is not None)
scheduled_parents = dedup(xsi for x in lsi.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
for x in scheduled_parents:
graph[x].append(lsi)
in_degree[lsi] += 1
@@ -424,13 +417,12 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
queue = deque(lsi for lsi in prescheduled if in_degree[lsi] == 0)
schedule: List[ScheduleItem] = []
while queue:
lsi = queue.popleft()
schedule.append(si:=ScheduleItem(lsi.ast, tuple(b for u in lsi.ubufs if (b:=uop_bufs[u]).size != 0), lsi.metadata))
schedule.append(si:=queue.popleft())
for b in si.outputs: del lazybufs_to_realize[b].srcs # can only schedule once
if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
if DEBUG >= 3: print(si)
raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
for x in graph[lsi]:
for x in graph[si]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)