|
|
|
|
@@ -66,9 +66,6 @@ sym = symbolic_simple+PatternMatcher([
|
|
|
|
|
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
|
|
|
|
|
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
|
|
|
|
|
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
|
|
|
|
|
# assigns last
|
|
|
|
|
(UPat(GroupOp.All, name="root"),
|
|
|
|
|
lambda root: root.replace(src=n) if (n:=tuple(sorted(root.src, key=lambda x:0 if x.op is Ops.ASSIGN else -1))) != root.src else None),
|
|
|
|
|
# remove CONST/BIND/BUFFER/VIEW from SINK
|
|
|
|
|
(UPat(Ops.SINK, name="root"),
|
|
|
|
|
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
|
|
|
|
|
@@ -93,6 +90,7 @@ class ScheduleContext:
|
|
|
|
|
allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
|
|
|
|
|
var_vals: dict[Variable, int] = field(default_factory=dict)
|
|
|
|
|
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
|
|
|
|
preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
|
|
|
|
|
|
|
|
|
# wrap tensor uops around a VIEW(BUFFER, <uop>)
|
|
|
|
|
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
|
|
|
|
|
@@ -245,11 +243,12 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
|
|
|
|
|
|
|
|
|
|
break_sched = PatternMatcher([
|
|
|
|
|
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
|
|
|
|
|
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), lambda st,b: UOp(Ops.LOAD, b.dtype.base, (b, st.st.to_uop()))),
|
|
|
|
|
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)),
|
|
|
|
|
lambda ctx,st,b: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, st.st.to_uop()))),
|
|
|
|
|
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
# **** ScheduleItem creation (TODO: replace ScheduleItem with the KERNEL UOp)
|
|
|
|
|
# **** convert Kernel to a ScheduleItem (for legacy reasons)
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
|
|
class ScheduleItem:
|
|
|
|
|
@@ -267,6 +266,18 @@ class ScheduleItem:
|
|
|
|
|
@functools.cached_property
|
|
|
|
|
def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
|
|
|
|
|
|
|
|
|
|
def kernel_to_si(k:UOp) -> ScheduleItem:
  """Convert a KERNEL UOp into a ScheduleItem.

  The item is built from `k.arg.ast`, one `u.buf_uop.buffer` per source `u` of `k`, and `k.metadata`.
  Asserts that `k.op` is Ops.KERNEL and that `k.metadata` is a tuple.
  """
  assert k.op is Ops.KERNEL and isinstance(k.metadata, tuple), f"must be KERNEL {k}"
  ast = k.arg.ast
  # every source of the KERNEL is resolved to the buffer behind its buf_uop
  bufs = tuple(u.buf_uop.buffer for u in k.src)
  return ScheduleItem(ast, bufs, k.metadata)
|
|
|
|
|
|
|
|
|
|
# **** Kernel creation
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Kernel:
  """Immutable record stored as the arg of a KERNEL UOp: the kernel's AST plus its metadata tuple."""
  ast: UOp
  metadata: tuple[Metadata, ...]
  def __repr__(self):
    # report the number of uops in the AST's toposort rather than printing the whole graph
    n_uops = len(list(self.ast.toposort))
    return f"<Kernel {n_uops} {self.ast.op} {self.metadata}>"
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
|
|
class KernelContext:
|
|
|
|
|
var_vals: dict[Variable, int]
|
|
|
|
|
@@ -345,18 +356,21 @@ def check_load_st(glbl:UOp, view:UOp):
|
|
|
|
|
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
|
|
|
|
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
|
|
|
|
|
|
|
|
|
fix_kernel_ops = PatternMatcher([
|
|
|
|
|
# BUFFER becomes DEFINE_GLOBAL
|
|
|
|
|
to_si = PatternMatcher([
|
|
|
|
|
# BUFFER -> DEFINE_GLOBAL
|
|
|
|
|
(UPat(Ops.BUFFER, name="x"), _append_buf),
|
|
|
|
|
# BIND in shapetracker becomes DEFINE_VAR
|
|
|
|
|
# simplify and unbind the final VIEWs
|
|
|
|
|
(UPat(Ops.VIEW, name="x"), _append_st_vars),
|
|
|
|
|
# remove SINK from COPY and BUFFER_VIEW
|
|
|
|
|
# don't need SINK on COPY or BUFFER_VIEW
|
|
|
|
|
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
|
|
|
|
|
# remove CONTIGUOUS/ASSIGN/DEVICE
|
|
|
|
|
# don't need contiguous or assign anymore
|
|
|
|
|
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
|
|
|
|
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
|
|
|
|
|
# don't need DEVICE anymore
|
|
|
|
|
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
|
|
|
|
|
# no ImageDType after load
|
|
|
|
|
# PRELOAD becomes LOAD
|
|
|
|
|
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
|
|
|
|
|
# once images are loaded they become the base dtype
|
|
|
|
|
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
|
|
|
|
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
|
|
|
|
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
|
|
|
|
|
@@ -367,14 +381,22 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
|
|
|
|
|
return var
|
|
|
|
|
# matches every BIND whose two sources are captured as `var` and `val`, and hands it to unbind_variable
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
|
|
|
|
|
|
|
|
|
|
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
|
|
|
|
|
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp:
|
|
|
|
|
# unbind_vars + push views to edges
|
|
|
|
|
sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=ctx.var_vals), view_right)
|
|
|
|
|
# fix_kernel_ops
|
|
|
|
|
sink = graph_rewrite(sink, fix_kernel_ops, si_ctx:=KernelContext(ctx.var_vals))
|
|
|
|
|
# remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
|
|
|
|
|
ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(ctx.var_vals))
|
|
|
|
|
# deal with ASSIGN
|
|
|
|
|
if len(ctx.assigns) != 0:
|
|
|
|
|
assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
|
|
|
|
|
for x in list(sink.toposort)[::-1]:
|
|
|
|
|
# we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
|
|
|
|
|
if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
|
|
|
|
|
# PRELOAD tells the toposort this kernel should run before ASSIGN
|
|
|
|
|
if x.op is Ops.PRELOAD: assign_preloads[x.buf_uop] = None
|
|
|
|
|
# NOTE: we only add the metadata for fused tensors
|
|
|
|
|
metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
|
|
|
|
|
return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs), metadata)
|
|
|
|
|
return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))
|
|
|
|
|
|
|
|
|
|
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
|
|
|
|
if CAPTURE_PROCESS_REPLAY:
|
|
|
|
|
@@ -382,35 +404,9 @@ if CAPTURE_PROCESS_REPLAY:
|
|
|
|
|
def save_process_replay():
|
|
|
|
|
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Kernel:
  """Immutable wrapper around a kernel's AST, used as the arg of a KERNEL UOp."""
  ast: UOp
  def __repr__(self):
    # report the number of uops in the AST's toposort rather than printing the whole graph
    n_uops = len(list(self.ast.toposort))
    return f"<Kernel {n_uops} {self.ast.op}>"
|
|
|
|
|
|
|
|
|
|
# NOTE: realizes become ASSIGN(BUFFER, KERNEL) in the schedule graph
|
|
|
|
|
def init_kernel(ctx:dict[UOp, UOp], u:UOp):
  """Return `u.buf_uop` assigned a fresh KERNEL UOp.

  The KERNEL keeps the sources of `u`; its arg is a Kernel wrapping `ctx[u.buf_uop].sink()`.
  """
  ast = ctx[u.buf_uop].sink()
  kernel = UOp(Ops.KERNEL, src=u.src, arg=Kernel(ast))
  return u.buf_uop.assign(kernel)
|
|
|
|
|
def is_kernel(u:UOp) -> bool:
  """Return True when `u` is an ASSIGN whose second source is a KERNEL."""
  if u.op is not Ops.ASSIGN: return False
  return u.src[1].op is Ops.KERNEL
|
|
|
|
|
|
|
|
|
|
# ops that are always kept as-is among a kernel's sources (never fused into it)
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.BUFFER}
|
|
|
|
|
def append_to_kernel(ctx:dict[UOp, UOp], x:UOp):
  """Rebuild the sources of `x`, fusing every source that does not have to stay separate.

  A source is kept unchanged when its op is in DONT_PLACE_IN_KERNEL, when it is a kernel ASSIGN,
  or when it is scheduled and its buf_uop is a key of `ctx`. Any other source is replaced by its
  own sources (`uval(s).src` when scheduled, otherwise `s.src`).
  Returns `x` with the deduplicated new sources, or None when they equal the current ones.
  """
  kept: list[UOp] = []
  for s in x.src:
    # these ops never fuse, and neither does anything we are realizing
    if s.op in DONT_PLACE_IN_KERNEL or is_kernel(s) or (is_scheduled(s) and s.buf_uop in ctx):
      kept.append(s)
    else:
      # fuse this op: pull its sources up in its place
      kept.extend(uval(s).src if is_scheduled(s) else s.src)
  n = tuple(dedup(kept))
  if n != x.src: return x.replace(src=n)
  return None
|
|
|
|
|
|
|
|
|
|
create_kernels = PatternMatcher([
|
|
|
|
|
(UPat(Ops.SINK, name="x"), lambda ctx,x: x.replace(src=tuple(init_kernel(ctx, s) for s in x.src))
|
|
|
|
|
if any(not is_kernel(s) for s in x.src) else None),
|
|
|
|
|
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
|
|
|
|
|
(UPat(Ops.SINK, name="x"), lambda ctx,x: x.replace(src=tuple(schedule_uop(s.sink(), ctx) for s in x.src))
|
|
|
|
|
if any(s.op is not Ops.KERNEL for s in x.src) else None),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
# **** schedule creation and toposort
|
|
|
|
|
@@ -448,59 +444,44 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
|
|
|
|
becomes_map[tensor_uop] = tensor_uop.src[0] if tensor_uop.op is Ops.ASSIGN else buf_uop.reshape(tensor_uop.shape)
|
|
|
|
|
buf_uop.buffer.ref(1)
|
|
|
|
|
|
|
|
|
|
# break the sink into kernels
|
|
|
|
|
# create kernels, TODO: this should use the SINK from tensor_map
|
|
|
|
|
graph_rewrite(sink, break_sched, ctx)
|
|
|
|
|
# create the kernel graph
|
|
|
|
|
sched_sink = sink
|
|
|
|
|
kernel_assign: dict[UOp, UOp] = {}
|
|
|
|
|
before_assign: dict[UOp, dict[UOp, UOp]] = {}
|
|
|
|
|
while 1:
|
|
|
|
|
sched_sink = graph_rewrite(sched_sink, create_kernels, realize_map)
|
|
|
|
|
rep: dict[UOp, UOp] = {}
|
|
|
|
|
for u in sched_sink.toposort:
|
|
|
|
|
if not is_kernel(u): continue
|
|
|
|
|
kernel_assign[u.buf_uop] = u
|
|
|
|
|
for s in u.src[1].src:
|
|
|
|
|
if s.op is Ops.BUFFER and s is not u.buf_uop: before_assign.setdefault(s, {})[u.buf_uop] = u
|
|
|
|
|
if s.op in DONT_PLACE_IN_KERNEL or is_kernel(s): continue
|
|
|
|
|
# otherwise it becomes a new kernel
|
|
|
|
|
rep[s] = init_kernel(realize_map, s)
|
|
|
|
|
if len(rep) == 0: break
|
|
|
|
|
sched_sink = sched_sink.substitute(rep)
|
|
|
|
|
sched_sink = graph_rewrite(UOp.sink(*realize_map.values()), create_kernels, ctx)
|
|
|
|
|
type_verify(list(sched_sink.toposort), kernel_spec)
|
|
|
|
|
|
|
|
|
|
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
|
|
|
|
assign_deps: dict[UOp, UOp] = {}
|
|
|
|
|
for k,v in kernel_assign.items():
|
|
|
|
|
if (deps:=before_assign.get(k)) is None: continue
|
|
|
|
|
for x in deps.values():
|
|
|
|
|
if any(xp.op is Ops.ASSIGN and xp.buf_uop is k for xp in x.toposort):
|
|
|
|
|
raise RuntimeError(f"cycle detected in graph, kernel must either depend on ASSIGN or BUFFER for {k}")
|
|
|
|
|
assign_deps[v] = v.replace(src=v.src+tuple(deps.values()))
|
|
|
|
|
if assign_deps: sched_sink = sched_sink.substitute(assign_deps)
|
|
|
|
|
# TODO: this should be the break between the "grouper" and the "linearizer"
|
|
|
|
|
# here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
|
|
|
|
|
# call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
|
|
|
|
|
|
|
|
|
|
# final toposort (bfs)
|
|
|
|
|
children: dict[UOp, list[UOp]] = {}
|
|
|
|
|
in_degree: dict[UOp, int] = {}
|
|
|
|
|
for u in sched_sink.toposort:
|
|
|
|
|
if u.op is not Ops.ASSIGN: continue
|
|
|
|
|
in_degree[u] = 0
|
|
|
|
|
for s in u.src[1].src:
|
|
|
|
|
if s.op is not Ops.ASSIGN: continue
|
|
|
|
|
children.setdefault(s, []).append(u)
|
|
|
|
|
in_degree[u] += 1
|
|
|
|
|
# convert kernels to ScheduleItem
|
|
|
|
|
prescheduled = [kernel_to_si(k) for k in sched_sink.src]
|
|
|
|
|
# add ScheduleItem children
|
|
|
|
|
# TODO: this should construct the graph directly from the sched_sink
|
|
|
|
|
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
|
|
|
|
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
|
|
|
|
|
in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
|
|
|
|
|
for si in prescheduled:
|
|
|
|
|
# realize outputs before a parent is assigned to
|
|
|
|
|
parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) is not None and xsi is not si)
|
|
|
|
|
for assign in parents_assigns:
|
|
|
|
|
graph[si].append(assign)
|
|
|
|
|
in_degree[assign] += 1
|
|
|
|
|
# realize outputs after all parents are realized
|
|
|
|
|
scheduled_parents = dedup(xsi for x in si.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
|
|
|
|
|
for x in scheduled_parents:
|
|
|
|
|
graph[x].append(si)
|
|
|
|
|
in_degree[si] += 1
|
|
|
|
|
|
|
|
|
|
queue = deque(k for k,v in in_degree.items() if v == 0)
|
|
|
|
|
# do BFS
|
|
|
|
|
queue = deque(si for si in prescheduled if in_degree[si] == 0)
|
|
|
|
|
schedule: list[ScheduleItem] = []
|
|
|
|
|
while queue:
|
|
|
|
|
u = queue.popleft()
|
|
|
|
|
schedule.append(schedule_uop(u.src[1].arg.ast, ctx))
|
|
|
|
|
for x in children.get(u, []):
|
|
|
|
|
schedule.append(si:=queue.popleft())
|
|
|
|
|
for x in graph[si]:
|
|
|
|
|
in_degree[x] -= 1
|
|
|
|
|
if in_degree[x] == 0: queue.append(x)
|
|
|
|
|
|
|
|
|
|
# confirm everything was scheduled correctly
|
|
|
|
|
if len(schedule) != (groups:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
|
|
|
|
|
if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
|
|
|
|
|
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
|
|
|
|
|
# capture process replay
|
|
|
|
|
if CAPTURE_PROCESS_REPLAY:
|
|
|
|
|
|