mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
reorder simplifier and grouper logic in scheduler [pr] (#8861)
This commit is contained in:
@@ -203,7 +203,57 @@ if CAPTURE_PROCESS_REPLAY:
|
||||
def save_process_replay() -> None:
|
||||
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
|
||||
|
||||
# **** Schedule grouping
|
||||
# **** UOp realization
|
||||
|
||||
class UPatScheduled(UPat):
  """Pattern for a scheduled value: a VIEW of a BUFFER wrapping the op to store.

  Matches VIEW (bound as "base") whose sources are a BUFFER (bound as "b") and
  an inner pattern built from the constructor arguments (bound as "to_store"
  unless the caller supplies its own name).
  """
  def __init__(self, *pat_args, **pat_kwargs):
    # default the inner pattern's name to "to_store"; a caller-supplied name wins
    inner = UPat(*pat_args, **{"name":"to_store", **pat_kwargs})
    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), inner))
|
||||
|
||||
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None:
  """Mark buffer `b` for realization by recording the op that stores into it."""
  ctx.realizes[b] = to_store
|
||||
|
||||
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
  """Decide whether `src` must be realized before this VIEW can be applied to it."""
  st = unwrap(view.st)
  # a single masked view is a simple pad: fold it when the padded region is no larger than the source
  if len(st.views) == 1 and all_int(src.shape):
    mask = st.views[-1].mask
    if mask is not None and resolve(prod(src.shape) >= prod([hi-lo for lo,hi in mask])):
      if can_pad(src, ctx.realizes, set()): return None
      return realize(ctx, b, src)
  # an expanding view multiplies reads: realize the smaller source first (unless disabled via env)
  if not getenv("DONT_REALIZE_EXPAND") and resolve(prod(src.shape) < prod(st.shape)):
    return realize(ctx, b, src)
  # otherwise, any remaining masked view is a pad: realize unless padding through src is known safe
  if all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set()): return None
  return realize(ctx, b, src)
|
||||
|
||||
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
  """Collapse an image-to-image float CAST: reuse the source image buffer instead of realizing `b`."""
  # only fold when the source buffer is an image
  if not isinstance(xb.dtype, ImageDType): return None
  # both sides must currently be marked for realization
  if b not in ctx.realizes or xb not in ctx.realizes: return None
  # COPY must keep its own realized output
  if uval(x.base).op is Ops.COPY: return None
  del ctx.realizes[b]
  return x.view(unwrap(view.st))
|
||||
|
||||
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
  """On DISK devices, turn BITCAST/CONTIGUOUS into a zero-copy BUFFER_VIEW of x's buffer."""
  device = b.device
  # multi-device buffers and non-DISK devices don't support subbuffer views
  if isinstance(device, tuple) or not device.startswith("DISK"): return None
  byte_offset = unwrap(x.st).views[0].offset * x.dtype.itemsize
  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, byte_offset)
  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
|
||||
|
||||
# rules deciding which ops get realized into buffers (order matters: first match wins)
do_realize = PatternMatcher([
  # everything feeding a SINK is realized
  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
  # ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW are realization points by definition
  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
  # a VIEW of a scheduled op may force realization (expand, or a pad that isn't safe to fuse)
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
  # image-to-image float casts collapse instead of realizing
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
   fold_img_cast),
  # the sources of COPY and BUFFER_VIEW must exist as real buffers
  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  # on DISK, BITCAST/CONTIGUOUS are rewritten to zero-copy BUFFER_VIEWs
  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])
|
||||
|
||||
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
  """Register one scheduled VIEW in the context: fills allbufs, assigns and children maps."""
  ctx.allbufs[buf_uop] = view
  op = uval(view)
  if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
  # record buf_uop as a child of every scheduled buffer it reads from
  for parent in op.base.src:
    if is_scheduled(parent.base): ctx.children.setdefault(parent.base.buf_uop, {})[buf_uop] = None

# builds the ScheduleContext by visiting every scheduled VIEW(BUFFER, op) pair
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
|
||||
|
||||
def is_scheduled(u:UOp) -> bool:
  """True if `u` is in scheduled form: a VIEW with exactly (BUFFER, op) sources."""
  return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
|
||||
def uval(u:UOp) -> UOp:
|
||||
@@ -228,8 +278,9 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
|
||||
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
|
||||
recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
|
||||
|
||||
def group_realizes(ctx:ScheduleContext) -> None:
|
||||
"""search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
|
||||
def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
|
||||
# start by adding uops that always realize
|
||||
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
|
||||
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
|
||||
reduce_for_op: dict[UOp, UOp] = {}
|
||||
double_reduces: list[UOp] = []
|
||||
@@ -280,10 +331,28 @@ def group_realizes(ctx:ScheduleContext) -> None:
|
||||
for reduceop in double_reduces:
|
||||
top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
|
||||
if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
|
||||
graph_rewrite(sink, break_sched, ctx)
|
||||
return ctx.realizes
|
||||
|
||||
# **** Schedule creation and BFS toposort
|
||||
# break the SINK into stores
|
||||
|
||||
# ** this is schedule level const folding
|
||||
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
  """Turn a VIEW of a realized BUFFER into a LOAD (or PRELOAD when `b` is also assigned)."""
  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
  load_op = Ops.PRELOAD if b in ctx.assigns else Ops.LOAD
  return UOp(load_op, b.dtype.base, (b, unwrap(st.st).to_uop()))
|
||||
|
||||
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
  """Either fuse `x` into its consumer (unrealized) or STORE it to `b` and LOAD it back."""
  metadata = ctx.tensor_uops[b][-1].metadata
  if metadata is not None: ctx.ops_metadata[x] = metadata
  # unrealized buffers collapse into their op and fuse with the consumer
  if b not in ctx.realizes: return x
  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
|
||||
|
||||
# rewrites each VIEW of a BUFFER into LOAD/STORE, or fuses the underlying op
break_sched = PatternMatcher([
  # a VIEW of a bare BUFFER becomes a LOAD/PRELOAD
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
  # a VIEW of (BUFFER, op) either STOREs the op or fuses it into the consumer
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
|
||||
|
||||
# **** schedule simplifier
|
||||
|
||||
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
|
||||
if not all_int(x.shape): return None
|
||||
@@ -338,80 +407,6 @@ sym = symbolic_simple+PatternMatcher([
|
||||
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
|
||||
])
|
||||
|
||||
# ** this decides which ops get realized
|
||||
|
||||
class UPatScheduled(UPat):
  """Pattern for a scheduled value: a VIEW of a BUFFER wrapping the op to store.

  Matches VIEW (bound as "base") whose sources are a BUFFER (bound as "b") and
  an inner pattern built from the constructor arguments (bound as "to_store"
  unless the caller supplies its own name).
  """
  def __init__(self, *pat_args, **pat_kwargs):
    # default the inner pattern's name to "to_store"; a caller-supplied name wins
    inner = UPat(*pat_args, **{"name":"to_store", **pat_kwargs})
    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), inner))
|
||||
|
||||
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None:
  """Mark buffer `b` for realization by recording the op that stores into it."""
  ctx.realizes[b] = to_store
|
||||
|
||||
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
  """Decide whether `src` must be realized before this VIEW can be applied to it."""
  st = unwrap(view.st)
  # a single masked view is a simple pad: fold it when the padded region is no larger than the source
  if len(st.views) == 1 and all_int(src.shape):
    mask = st.views[-1].mask
    if mask is not None and resolve(prod(src.shape) >= prod([hi-lo for lo,hi in mask])):
      if can_pad(src, ctx.realizes, set()): return None
      return realize(ctx, b, src)
  # an expanding view multiplies reads: realize the smaller source first (unless disabled via env)
  if not getenv("DONT_REALIZE_EXPAND") and resolve(prod(src.shape) < prod(st.shape)):
    return realize(ctx, b, src)
  # otherwise, any remaining masked view is a pad: realize unless padding through src is known safe
  if all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set()): return None
  return realize(ctx, b, src)
|
||||
|
||||
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
  """Collapse an image-to-image float CAST: reuse the source image buffer instead of realizing `b`."""
  # only fold when the source buffer is an image
  if not isinstance(xb.dtype, ImageDType): return None
  # both sides must currently be marked for realization
  if b not in ctx.realizes or xb not in ctx.realizes: return None
  # COPY must keep its own realized output
  if uval(x.base).op is Ops.COPY: return None
  del ctx.realizes[b]
  return x.view(unwrap(view.st))
|
||||
|
||||
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
  """On DISK devices, turn BITCAST/CONTIGUOUS into a zero-copy BUFFER_VIEW of x's buffer."""
  device = b.device
  # multi-device buffers and non-DISK devices don't support subbuffer views
  if isinstance(device, tuple) or not device.startswith("DISK"): return None
  byte_offset = unwrap(x.st).views[0].offset * x.dtype.itemsize
  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, byte_offset)
  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
|
||||
|
||||
# rules deciding which ops get realized into buffers (order matters: first match wins)
do_realize = PatternMatcher([
  # everything feeding a SINK is realized
  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
  # ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW are realization points by definition
  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
  # a VIEW of a scheduled op may force realization (expand, or a pad that isn't safe to fuse)
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
  # image-to-image float casts collapse instead of realizing
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
   fold_img_cast),
  # the sources of COPY and BUFFER_VIEW must exist as real buffers
  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  # on DISK, BITCAST/CONTIGUOUS are rewritten to zero-copy BUFFER_VIEWs
  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])
|
||||
|
||||
# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
|
||||
|
||||
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
  """Turn a VIEW of a realized BUFFER into a LOAD (or PRELOAD when `b` is also assigned)."""
  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
  load_op = Ops.PRELOAD if b in ctx.assigns else Ops.LOAD
  return UOp(load_op, b.dtype.base, (b, unwrap(st.st).to_uop()))
|
||||
|
||||
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
  """Either fuse `x` into its consumer (unrealized) or STORE it to `b` and LOAD it back."""
  metadata = ctx.tensor_uops[b][-1].metadata
  if metadata is not None: ctx.ops_metadata[x] = metadata
  # unrealized buffers collapse into their op and fuse with the consumer
  if b not in ctx.realizes: return x
  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
|
||||
|
||||
# rewrites each VIEW of a BUFFER into LOAD/STORE, or fuses the underlying op
break_sched = PatternMatcher([
  # a VIEW of a bare BUFFER becomes a LOAD/PRELOAD
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
  # a VIEW of (BUFFER, op) either STOREs the op or fuses it into the consumer
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
|
||||
|
||||
# **** Schedule context builder
|
||||
|
||||
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
  """Register one scheduled VIEW in the context: fills allbufs, assigns and children maps."""
  ctx.allbufs[buf_uop] = view
  op = uval(view)
  if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
  # record buf_uop as a child of every scheduled buffer it reads from
  for parent in op.base.src:
    if is_scheduled(parent.base): ctx.children.setdefault(parent.base.buf_uop, {})[buf_uop] = None

# builds the ScheduleContext by visiting every scheduled VIEW(BUFFER, op) pair
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
|
||||
|
||||
# **** movement ops
|
||||
|
||||
remove_movement_ops = merge_views+PatternMatcher([
|
||||
# NOTE: movement ops are always applied to base
|
||||
(UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
|
||||
@@ -420,6 +415,8 @@ remove_movement_ops = merge_views+PatternMatcher([
|
||||
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
|
||||
])
|
||||
|
||||
# **** schedule creation and toposort
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
|
||||
@@ -438,11 +435,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
|
||||
# add BUFFER uops
|
||||
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
|
||||
# add realizes
|
||||
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
|
||||
# group realizes into kernels
|
||||
group_realizes(ctx)
|
||||
graph_rewrite(sink, break_sched, ctx)
|
||||
# get realizes
|
||||
realize_map = group_realizes(sink, ctx)
|
||||
|
||||
# TODO: this should be the break between the "grouper" and the "linearizer"
|
||||
# here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
|
||||
@@ -451,7 +445,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
# create schedule items + map buffers to realized tensors
|
||||
prescheduled: list[ScheduleItem] = []
|
||||
var_vals: dict[Variable, int] = {}
|
||||
for buf_uop,store in ctx.realizes.items():
|
||||
for buf_uop,store in realize_map.items():
|
||||
assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
|
||||
prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
|
||||
# can only schedule once
|
||||
|
||||
Reference in New Issue
Block a user