revert scheduler change (#9019)

* Revert "cleanup ast rewriter [pr] (#9012)"

This reverts commit bf0bcb2d5a.

* Revert "kernel op cleanups + use ScheduleItem [pr] (#9009)"

This reverts commit c52cd2b437.

* Revert "construct the schedule sink 2 (#8925)"

This reverts commit cfd3db7862.
This commit is contained in:
George Hotz
2025-02-11 11:34:12 +08:00
committed by GitHub
parent 16e9e4db37
commit fb698920f1
6 changed files with 75 additions and 96 deletions

View File

@@ -203,7 +203,6 @@ class TestAssign(unittest.TestCase):
np.testing.assert_equal(b0.numpy(), 128)
np.testing.assert_equal(b1.numpy(), 608)
@unittest.skip("TODO: bring this assert back")
def test_crossunder_assign(self):
# NOTE: should *not* raise AssertionError from numpy
with self.assertRaisesRegex(RuntimeError, "cycle"):

View File

@@ -10,7 +10,7 @@ from tinygrad.device import Buffer, Device
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
from tinygrad.spec import spec
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.schedule import fix_kernel_ops
from tinygrad.engine.schedule import to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.rewriter import full_graph_rewrite, sym
@@ -487,7 +487,7 @@ class TestIndexingOrdering(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "rewriter.py")
self.assertEqual(fix_kernel_ops.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(Ops.CONST, dtypes.bool)

View File

@@ -66,9 +66,6 @@ sym = symbolic_simple+PatternMatcher([
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
# assigns last
(UPat(GroupOp.All, name="root"),
lambda root: root.replace(src=n) if (n:=tuple(sorted(root.src, key=lambda x:0 if x.op is Ops.ASSIGN else -1))) != root.src else None),
# remove CONST/BIND/BUFFER/VIEW from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
@@ -93,6 +90,7 @@ class ScheduleContext:
allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
var_vals: dict[Variable, int] = field(default_factory=dict)
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
@@ -245,11 +243,12 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
break_sched = PatternMatcher([
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), lambda st,b: UOp(Ops.LOAD, b.dtype.base, (b, st.st.to_uop()))),
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)),
lambda ctx,st,b: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, st.st.to_uop()))),
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
# **** ScheduleItem creation (TODO: replace ScheduleItem with the KERNEL UOp)
# **** convert Kernel to a ScheduleItem (for legacy reasons)
@dataclass(frozen=True)
class ScheduleItem:
@@ -267,6 +266,18 @@ class ScheduleItem:
@functools.cached_property
def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
def kernel_to_si(k:UOp) -> ScheduleItem:
assert k.op is Ops.KERNEL and isinstance(k.metadata, tuple), f"must be KERNEL {k}"
return ScheduleItem(k.arg.ast, tuple(u.buf_uop.buffer for u in k.src), k.metadata)
# **** Kernel creation
@dataclass(frozen=True)
class Kernel:
ast: UOp
metadata: tuple[Metadata, ...]
def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
@dataclass(frozen=True)
class KernelContext:
var_vals: dict[Variable, int]
@@ -345,18 +356,21 @@ def check_load_st(glbl:UOp, view:UOp):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
fix_kernel_ops = PatternMatcher([
# BUFFER becomes DEFINE_GLOBAL
to_si = PatternMatcher([
# BUFFER -> DEFINE_GLOBAL
(UPat(Ops.BUFFER, name="x"), _append_buf),
# BIND in shapetracker becomes DEFINE_VAR
# simplify and unbind the final VIEWs
(UPat(Ops.VIEW, name="x"), _append_st_vars),
# remove SINK from COPY and BUFFER_VIEW
# don't need SINK on COPY or BUFFER_VIEW
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
# remove CONTIGUOUS/ASSIGN/DEVICE
# don't need contiguous or assign anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
# don't need DEVICE anymore
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
# no ImageDType after load
# PRELOAD becomes LOAD
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
# once images are loaded they become the base dtype
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
@@ -367,14 +381,22 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp:
# unbind_vars + push views to edges
sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=ctx.var_vals), view_right)
# fix_kernel_ops
sink = graph_rewrite(sink, fix_kernel_ops, si_ctx:=KernelContext(ctx.var_vals))
# remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(ctx.var_vals))
# deal with ASSIGN
if len(ctx.assigns) != 0:
assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
for x in list(sink.toposort)[::-1]:
# we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
# PRELOAD tells the toposort this kernel should run before ASSIGN
if x.op is Ops.PRELOAD: assign_preloads[x.buf_uop] = None
# NOTE: we only add the metadata for fused tensors
metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs), metadata)
return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@@ -382,35 +404,9 @@ if CAPTURE_PROCESS_REPLAY:
def save_process_replay():
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
@dataclass(frozen=True)
class Kernel:
ast: UOp
def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op}>"
# NOTE: realizes become ASSIGN(BUFFER, KERNEL) in the schedule graph
def init_kernel(ctx:dict[UOp, UOp], u:UOp): return u.buf_uop.assign(UOp(Ops.KERNEL, src=u.src, arg=Kernel(ctx[u.buf_uop].sink())))
def is_kernel(u:UOp) -> bool: return u.op is Ops.ASSIGN and u.src[1].op is Ops.KERNEL
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.BUFFER}
def append_to_kernel(ctx:dict[UOp, UOp], x:UOp):
new_src: list[UOp] = []
for s in x.src:
# these ops never fuse
if s.op in DONT_PLACE_IN_KERNEL or is_kernel(s): pass
# otherwise check if we're realizing it
elif is_scheduled(s) and s.buf_uop in ctx: pass
else:
# fuse this op!
new_src.extend(uval(s).src if is_scheduled(s) else s.src)
continue
# don't fuse this op
new_src.append(s)
return x.replace(src=n) if (n:=tuple(dedup(new_src))) != x.src else None
create_kernels = PatternMatcher([
(UPat(Ops.SINK, name="x"), lambda ctx,x: x.replace(src=tuple(init_kernel(ctx, s) for s in x.src))
if any(not is_kernel(s) for s in x.src) else None),
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
(UPat(Ops.SINK, name="x"), lambda ctx,x: x.replace(src=tuple(schedule_uop(s.sink(), ctx) for s in x.src))
if any(s.op is not Ops.KERNEL for s in x.src) else None),
])
# **** schedule creation and toposort
@@ -448,59 +444,44 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
becomes_map[tensor_uop] = tensor_uop.src[0] if tensor_uop.op is Ops.ASSIGN else buf_uop.reshape(tensor_uop.shape)
buf_uop.buffer.ref(1)
# break the sink into kernels
# create kernels, TODO: this should use the SINK from tensor_map
graph_rewrite(sink, break_sched, ctx)
# create the kernel graph
sched_sink = sink
kernel_assign: dict[UOp, UOp] = {}
before_assign: dict[UOp, dict[UOp, UOp]] = {}
while 1:
sched_sink = graph_rewrite(sched_sink, create_kernels, realize_map)
rep: dict[UOp, UOp] = {}
for u in sched_sink.toposort:
if not is_kernel(u): continue
kernel_assign[u.buf_uop] = u
for s in u.src[1].src:
if s.op is Ops.BUFFER and s is not u.buf_uop: before_assign.setdefault(s, {})[u.buf_uop] = u
if s.op in DONT_PLACE_IN_KERNEL or is_kernel(s): continue
# otherwise it becomes a new kernel
rep[s] = init_kernel(realize_map, s)
if len(rep) == 0: break
sched_sink = sched_sink.substitute(rep)
sched_sink = graph_rewrite(UOp.sink(*realize_map.values()), create_kernels, ctx)
type_verify(list(sched_sink.toposort), kernel_spec)
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
assign_deps: dict[UOp, UOp] = {}
for k,v in kernel_assign.items():
if (deps:=before_assign.get(k)) is None: continue
for x in deps.values():
if any(xp.op is Ops.ASSIGN and xp.buf_uop is k for xp in x.toposort):
raise RuntimeError(f"cycle detected in graph, kernel must either depend on ASSIGN or BUFFER for {k}")
assign_deps[v] = v.replace(src=v.src+tuple(deps.values()))
if assign_deps: sched_sink = sched_sink.substitute(assign_deps)
# TODO: this should be the break between the "grouper" and the "linearizer"
# here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
# call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
# final toposort (bfs)
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort:
if u.op is not Ops.ASSIGN: continue
in_degree[u] = 0
for s in u.src[1].src:
if s.op is not Ops.ASSIGN: continue
children.setdefault(s, []).append(u)
in_degree[u] += 1
# convert kernels to ScheduleItem
prescheduled = [kernel_to_si(k) for k in sched_sink.src]
# add ScheduleItem children
# TODO: this should construct the graph directly from the sched_sink
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
for si in prescheduled:
# realize outputs before a parent is assigned to
parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) is not None and xsi is not si)
for assign in parents_assigns:
graph[si].append(assign)
in_degree[assign] += 1
# realize outputs after all parents are realized
scheduled_parents = dedup(xsi for x in si.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
for x in scheduled_parents:
graph[x].append(si)
in_degree[si] += 1
queue = deque(k for k,v in in_degree.items() if v == 0)
# do BFS
queue = deque(si for si in prescheduled if in_degree[si] == 0)
schedule: list[ScheduleItem] = []
while queue:
u = queue.popleft()
schedule.append(schedule_uop(u.src[1].arg.ast, ctx))
for x in children.get(u, []):
schedule.append(si:=queue.popleft())
for x in graph[si]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
# confirm everything was scheduled correctly
if len(schedule) != (groups:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
# capture process replay
if CAPTURE_PROCESS_REPLAY:

View File

@@ -94,7 +94,7 @@ class MathTrait(SimpleMathTrait):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto() # noqa: E702
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto(); KERNEL = auto() # noqa: E702
# TODO: empty continues to exist because of tensor
EMPTY = auto()
@@ -163,7 +163,7 @@ class GroupOp:
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
# BinaryOps that can be flipped
@@ -500,7 +500,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
@property
def metadata(self) -> Metadata|None: return all_metadata.get(self, None)
def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
# *** uop movement ops ***

View File

@@ -126,9 +126,8 @@ spec = PatternMatcher([
kernel_spec = PatternMatcher([
(UPat(Ops.DEVICE, src=()), lambda: True),
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE),)), lambda: True),
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
# NOTE: assign always has a (BUF, KERNEL), it can also optionally depend on other assigns
(UPat(Ops.ASSIGN, src=[UPat(Ops.BUFFER), UPat((Ops.KERNEL, Ops.ASSIGN))], allow_any_len=True), lambda: True),
# TODO: currently kernel only has buffer parents, this is incomplete. it should be BUFFER and ASSIGN
(UPat(Ops.KERNEL, src=UPat(Ops.BUFFER)), lambda: True),
])
# *** this is the UOp shape spec ***

View File

@@ -9,7 +9,7 @@ from tinygrad.codegen.kernel import Kernel
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",