@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
@@ -87,7 +87,6 @@ class ScheduleContext:
  allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops to the actual op
  var_vals: dict[Variable, int] = field(default_factory=dict)
  children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
  preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))

# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
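# ---- editorial note (illustrative sketch, not part of the diff) ----
# ScheduleContext builds every mutable field with field(default_factory=...). The reason is a
# dataclass rule rather than a scheduler detail: mutable defaults have to be constructed per
# instance. A tiny standalone example of the same pattern (toy names, no tinygrad imports):
from collections import defaultdict
from dataclasses import dataclass, field

@dataclass
class Ctx:
  realizes: dict[str, str] = field(default_factory=dict)  # fresh dict per Ctx instance
  children: defaultdict[str, list[str]] = field(default_factory=lambda: defaultdict(list))

a, b = Ctx(), Ctx()
a.children["x"].append("y")
assert b.children == {}  # no state leaks between instances
# ---- end note ----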
@@ -230,19 +229,33 @@ def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
    if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
  return ctx.realizes

# break the SINK into stores
# break the SINK into kernels

def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):

@dataclass(frozen=True)
class Kernel:
  ast: UOp
  metadata: tuple[Metadata, ...]
  def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"

def create_kernel(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
  if (m:=ctx.ops_metadata.get(b)) is not None: ctx.ops_metadata[x] = m
  if b not in ctx.realizes: return x # collapse BUFFER
  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
  # KERNEL nodes become: ASSIGN(VIEW(BUFFER), KERNEL)
  return b.view(ShapeTracker.from_shape(x.shape)).assign(UOp(Ops.KERNEL, src=st.src, arg=Kernel(x, (m,) if m is not None else ())))

break_sched = PatternMatcher([
  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)),
   lambda ctx,st,b: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, st.st.to_uop()))),
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),

def append_to_kernel(ctx:ScheduleContext, x:UOp):
  new_srcs: list[UOp] = []
  new_metadata: dict[Metadata, None] = dict.fromkeys(x.arg.metadata)
  for s in x.src:
    if s.op is Ops.BUFFER or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL): new_srcs.append(s)
    else:
      new_srcs.extend(s.src)
      if (m:=ctx.ops_metadata.get(s)) is not None: new_metadata[m] = None
  return x.replace(src=n, arg=Kernel(x.arg.ast, tuple(new_metadata))) if (n:=tuple(dedup(new_srcs))) != x.src else None

create_kernels = merge_views+PatternMatcher([
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), create_kernel),
  (UPat(Ops.KERNEL, name="x"), append_to_kernel),
])
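# ---- editorial note (illustrative sketch, not part of the diff) ----
# The shape of the new grouping pass: create_kernel wraps a realized op into a kernel node, and
# append_to_kernel keeps folding unrealized sources into that kernel until only BUFFERs (or other
# kernels) remain as its srcs. A standalone toy version of that fixed point (hypothetical Node
# type, not tinygrad's UOp/PatternMatcher):
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str                    # "buffer", "add", "mul", or "kernel"
  src: tuple = ()            # upstream nodes
  ast: "Node|None" = None    # for kernel nodes: the computation they will run

def toy_create_kernel(x: Node) -> Node:
  return Node("kernel", x.src, ast=x)   # analogous to ASSIGN(VIEW(BUFFER), KERNEL)

def toy_append_to_kernel(k: Node) -> Node:
  new_src: list[Node] = []
  for s in k.src:
    if s.op in ("buffer", "kernel"): new_src.append(s)   # stop at realized boundaries
    else: new_src.extend(s.src)                          # fuse: step over the compute op to its inputs
  if tuple(new_src) == k.src: return k                   # fixed point (the matcher would return None)
  return toy_append_to_kernel(Node("kernel", tuple(new_src), ast=k.ast))

a, b = Node("buffer"), Node("buffer")
out = Node("mul", (Node("add", (a, b)), b))
k = toy_append_to_kernel(toy_create_kernel(out))
assert all(s.op == "buffer" for s in k.src)   # the kernel now depends only on buffers
# ---- end note ----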
# **** convert Kernel to a ScheduleItem (for legacy reasons)

@@ -263,23 +276,8 @@ class ScheduleItem:
  @functools.cached_property
  def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)

def kernel_to_si(k:UOp) -> ScheduleItem:
  assert k.op is Ops.KERNEL and isinstance(k.metadata, tuple), f"must be KERNEL {k}"
  return ScheduleItem(k.arg.ast, tuple(u.buf_uop.buffer for u in k.src), k.metadata)

# **** Kernel creation

@dataclass(frozen=True)
class Kernel:
  ast: UOp
  metadata: tuple[Metadata, ...]
  def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"

@dataclass(frozen=True)
class KernelContext:
  var_vals: dict[Variable, int]
  bufs: list[UOp] = field(default_factory=list)

def apply_swizzle(u:UOp) -> UOp:
  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -332,17 +330,13 @@ view_right = merge_views+PatternMatcher([
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])

def _append_st_vars(ctx:KernelContext, x:UOp) -> UOp|None:
def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
  st = unwrap(x.st).simplify()
  if any(x.op is Ops.BIND for x in st.vars()):
    st, var_vals = st.unbind()
    ctx.var_vals.update(var_vals)
    ctx.update(var_vals)
  return st.to_uop() if st != x.st else None

def _append_buf(ctx:KernelContext, x:UOp) -> UOp:
  ctx.bufs.append(x)
  return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)

def check_load_st(glbl:UOp, view:UOp):
  if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
  # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
@@ -354,44 +348,48 @@ def check_load_st(glbl:UOp, view:UOp):
    +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))

fix_kernel_ops = PatternMatcher([
  # BUFFER becomes DEFINE_GLOBAL
  (UPat(Ops.BUFFER, name="x"), _append_buf),
  # BIND in shapetracker becomes DEFINE_VAR
  (UPat(Ops.VIEW, name="x"), _append_st_vars),
  # remove SINK from COPY and BUFFER_VIEW
  (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
  # remove CONTIGUOUS/ASSIGN/DEVICE/PRELOAD
  # remove CONTIGUOUS/ASSIGN/DEVICE
  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
  (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
  (UPat(Ops.PRELOAD, name="x"), lambda x: x.replace(op=Ops.LOAD)),
  # no ImageDType after load
  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
  # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
  (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
])
def load_buf(ctx:list[UOp], x:UOp):
  if x.base not in ctx: ctx.append(x.base)
  return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.base.size), (), ctx.index(x.base)), unwrap(x.st).to_uop()))

add_buffer_ops = PatternMatcher([
  # LOAD
  (UPat(Ops.ASSIGN, src=(UPat.var("x"), UPat(Ops.KERNEL))), load_buf),
  (UPat(Ops.BUFFER, name="x"), load_buf),
  # STORE (except for COPY/BUFFER_VIEW)
  (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
])
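# ---- editorial note (illustrative sketch, not part of the diff) ----
# load_buf assigns DEFINE_GLOBAL slots by interning each distinct buffer into the shared ctx list:
# the first time a buffer is seen it is appended, afterwards ctx.index() returns the same slot.
# Slot 0 stays reserved for the kernel's own output because schedule_uop seeds the list with it.
# The same interning pattern, standalone (toy string "buffers"):
def slot_for(bufs: list, buf: str) -> int:
  if buf not in bufs: bufs.append(buf)   # first occurrence gets the next free slot
  return bufs.index(buf)                 # later occurrences reuse it

bufs = ["output"]                        # slot 0 reserved for the output buffer
assert (slot_for(bufs, "a"), slot_for(bufs, "b"), slot_for(bufs, "a")) == (1, 2, 1)
# ---- end note ----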
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
  ctx[var.replace(src=())] = val.arg
  return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
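# ---- editorial note (illustrative sketch, not part of the diff) ----
# unbind_vars strips the concrete value off every BIND, leaving the bare variable in the graph and
# recording the binding in the shared var_vals dict. The idea without tinygrad (toy Var type,
# hypothetical names):
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
  name: str
  val: "int|None" = None   # None means unbound

def toy_unbind(expr: list, var_vals: dict) -> list:
  out = []
  for v in expr:
    if v.val is not None:
      var_vals[Var(v.name)] = v.val   # remember the binding on the side
      v = Var(v.name)                 # keep only the bare variable
    out.append(v)
  return out

var_vals: dict = {}
shape = [Var("batch", 4), Var("seqlen")]   # "batch" is bound to 4, "seqlen" stays symbolic
assert toy_unbind(shape, var_vals) == [Var("batch"), Var("seqlen")] and var_vals == {Var("batch"): 4}
# ---- end note ----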
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp:
def schedule_uop(sink:UOp, ctx:ScheduleContext) -> ScheduleItem:
  assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
  # start by adding buffer ops
  ast = graph_rewrite(sink.src[1].arg.ast.sink(), add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
  # unbind_vars + push views to edges
  sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=ctx.var_vals), view_right)
  # deal with ASSIGN
  if len(ctx.assigns) != 0:
    assign_preloads = ctx.preloads[pre.src[0].buf_uop.buffer]
    for x in list(sink.toposort)[::-1]:
      # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
      if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
      # PRELOAD tells the toposort this kernel should run before ASSIGN
      if x.op is Ops.PRELOAD: assign_preloads[x.buf_uop] = None
  ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=ctx.var_vals), view_right)
  # fix_kernel_ops
  sink = graph_rewrite(sink, fix_kernel_ops, si_ctx:=KernelContext(ctx.var_vals))
  # NOTE: we only add the metadata for fused tensors
  metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
  return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(sink, metadata))
  ast = graph_rewrite(ast, fix_kernel_ops, ctx.var_vals)
  return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@@ -399,11 +397,6 @@ if CAPTURE_PROCESS_REPLAY:
  def save_process_replay():
    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)

create_kernels = PatternMatcher([
  (UPat(Ops.SINK, name="x"), lambda ctx,x: x.replace(src=tuple(schedule_uop(s.sink(), ctx) for s in x.src))
    if any(s.op is not Ops.KERNEL for s in x.src) else None),
])
# **** schedule creation and toposort

@track_rewrites(named=True)
@@ -439,47 +432,50 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
      else: becomes_map[k] = v
    elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v

  # create kernels, TODO: this should use the SINK from tensor_map
  # create kernels
  if len(realize_map) == 0: return [], {}, becomes_map
  graph_rewrite(sink, break_sched, ctx)
  sched_sink = graph_rewrite(UOp.sink(*realize_map.values()), create_kernels, ctx)
  sched_sink = graph_rewrite(sink, create_kernels, ctx)
  type_verify(list(sched_sink.toposort), kernel_spec)
  # TODO: this should be the break between the "grouper" and the "linearizer"
  # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
  # call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`

  # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
  kernel_assign: dict[UOp, UOp] = {}
  assign_rep: dict[UOp, UOp] = {}
  for u in sched_sink.toposort:
    if u.op is not Ops.ASSIGN: continue
    kernel_assign[u.buf_uop] = u
    for s in u.src[1].src:
      if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
        raise RuntimeError(f"cycle detected in graph, kernel must either depend on ASSIGN or BUFFER for {s}")
      assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
  if assign_rep: sched_sink = sched_sink.substitute(assign_rep)
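# ---- editorial note (illustrative sketch, not part of the diff) ----
# The loop above encodes a classic write-after-read constraint: any kernel that reads a buffer must
# run before the ASSIGN that later overwrites that buffer, which is why the overwriting ASSIGN grows
# an extra src. The same idea standalone (toy Assign type, hypothetical names):
from dataclasses import dataclass, field

@dataclass
class ToyAssign:
  writes: str
  reads: tuple
  deps: list = field(default_factory=list)   # must run after these

def add_war_edges(assigns: list) -> None:
  readers: dict = {}
  for a in assigns:                          # assumes program order
    for r in readers.get(a.writes, []):      # whoever overwrites a buffer waits for its earlier readers
      if r is not a: a.deps.append(r)
    for b in a.reads: readers.setdefault(b, []).append(a)

k1 = ToyAssign(writes="out", reads=("w",))   # a kernel that reads w
k2 = ToyAssign(writes="w", reads=("x",))     # a later kernel that overwrites w
add_war_edges([k1, k2])
assert k2.deps == [k1]                       # k2 is ordered after the reader of the old w
# ---- end note ----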
  # display the final graph
  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
  # convert kernels to ScheduleItem
  prescheduled = [kernel_to_si(k) for k in sched_sink.src]
  # add ScheduleItem children
  # TODO: this should construct the graph directly from the sched_sink
  schedule_targets = {out:si for si in prescheduled for out in si.outputs}
  graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
  in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
  for si in prescheduled:
    # realize outputs before a parent is assigned to
    parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) is not None and xsi is not si)
    for assign in parents_assigns:
      graph[si].append(assign)
      in_degree[assign] += 1
    # realize outputs after all parents are realized
    scheduled_parents = dedup(xsi for x in si.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
    for x in scheduled_parents:
      graph[x].append(si)
      in_degree[si] += 1
  # final toposort (bfs)
  children: dict[UOp, list[UOp]] = {}
  in_degree: dict[UOp, int] = {}
  for u in sched_sink.toposort:
    if u.op is not Ops.ASSIGN: continue
    in_degree[u] = 0
    for s in u.src[1].src:
      if s.op is not Ops.ASSIGN: continue
      children.setdefault(s, []).append(u)
      in_degree[u] += 1
  # do BFS
  queue = deque(si for si in prescheduled if in_degree[si] == 0)
  queue = deque(k for k,v in in_degree.items() if v == 0)
  schedule: list[ScheduleItem] = []
  while queue:
    schedule.append(si:=queue.popleft())
    u = queue.popleft()
    schedule.append(si:=schedule_uop(u, ctx))
    # NOTE: incrementing output buffer refcounts is required by the memory planner and JIT
    for out in si.outputs: out.ref(1)
    for x in graph[si]:
    for x in children.get(u, []):
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
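# ---- editorial note (illustrative sketch, not part of the diff) ----
# The BFS above is Kahn's algorithm: start from nodes with no unscheduled dependencies, and each
# time a node is emitted, decrement its children's in-degrees; anything left unemitted means a cycle.
# The same algorithm standalone (string node names instead of UOps):
from collections import deque

def toy_toposort(children: dict, in_degree: dict) -> list:
  queue = deque(k for k, v in in_degree.items() if v == 0)
  order = []
  while queue:
    u = queue.popleft()
    order.append(u)
    for x in children.get(u, []):
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
  if len(order) != len(in_degree): raise RuntimeError("cycle detected in graph")
  return order

# B and C depend on A; D depends on B and C
assert toy_toposort({"A": ["B", "C"], "B": ["D"], "C": ["D"]},
                    {"A": 0, "B": 1, "C": 1, "D": 2}) == ["A", "B", "C", "D"]
# ---- end note ----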
  # confirm everything was scheduled correctly
  if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
  if len(schedule) != (groups:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")

  # capture process replay
  if CAPTURE_PROCESS_REPLAY: