reorder simplifier and grouper logic in scheduler [pr] (#8861)

Author: qazal
Date: 2025-02-02 10:19:52 -05:00
Committed by: GitHub
Parent: 83a904aaad
Commit: d64af3c884


@@ -203,7 +203,57 @@ if CAPTURE_PROCESS_REPLAY:
def save_process_replay() -> None:
  for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
# **** Schedule grouping
# **** UOp realization
class UPatScheduled(UPat):
  def __init__(self, *args, **kwargs):
    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
  st = unwrap(view.st)
  # fold simple pads
  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
  # early realize before expand
  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
  # otherwise safety check pads
  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
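A quick numeric reading of the pad-fold branch above; the shape and mask here are invented for illustration, not taken from the commit:

import math

src_shape = (4, 4)       # hypothetical source shape: 16 elements
mask = ((0, 2), (0, 3))  # per-axis (start, end) of the valid region: 2*3 = 6 elements
# the pad folds (no forced realize) only when the source has at least as many
# elements as the unmasked region, and can_pad approves the op
assert math.prod(src_shape) >= math.prod(y - x for x, y in mask)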
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
  del ctx.realizes[b]
  return x.view(unwrap(view.st))
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
do_realize = PatternMatcher([
  # always realize SINK parents
  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
  # realize before expand or unsafe pad ops
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
  # don't realize image to image casts
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
   fold_img_cast),
  # realize before COPY or BUFFER_VIEW
  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])
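To make the shape of this matcher concrete, here is a minimal toy model of what do_realize accomplishes: one graph walk that records, per buffer, the op that must be realized into it. Every name below (ToyOp, ToyNode, collect_realizes) is an invented stand-in, not tinygrad's UOp/UPat machinery.

from __future__ import annotations
from dataclasses import dataclass
from enum import Enum, auto

class ToyOp(Enum):
  BUFFER = auto(); CONTIGUOUS = auto(); MUL = auto(); SINK = auto()

@dataclass(frozen=True)
class ToyNode:
  op: ToyOp
  src: tuple[ToyNode, ...] = ()

def collect_realizes(sink: ToyNode) -> dict[ToyNode, ToyNode]:
  realizes: dict[ToyNode, ToyNode] = {}
  seen: set[ToyNode] = set()
  def walk(n: ToyNode) -> None:
    if n in seen: return
    seen.add(n)
    # mirrors "always realize SINK parents" and "always realize CONTIGUOUS"
    if n.op is ToyOp.SINK: realizes.update((s, s) for s in n.src)
    if n.op is ToyOp.CONTIGUOUS: realizes[n] = n
    for s in n.src: walk(s)
  walk(sink)
  return realizes

buf = ToyNode(ToyOp.BUFFER)
contig = ToyNode(ToyOp.CONTIGUOUS, (ToyNode(ToyOp.MUL, (buf, buf)),))
assert contig in collect_realizes(ToyNode(ToyOp.SINK, (contig,)))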
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
  ctx.allbufs[buf_uop] = view
  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
  for x in op.base.src:
    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
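append_uop builds ctx.children as a dict of dicts, using dict keys as an insertion-ordered set of consumers. A standalone sketch of the same idiom, with invented buffer names:

# producer buffer -> {consumer buffer: None}; keys preserve insertion order
children: dict[str, dict[str, None]] = {}
edges = [("a", "out"), ("a", "b"), ("b", "out")]  # hypothetical producer/consumer pairs
for producer, consumer in edges:
  children.setdefault(producer, {})[consumer] = None
assert list(children["a"]) == ["out", "b"]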
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
def uval(u:UOp) -> UOp:
@@ -228,8 +278,9 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
    recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
def group_realizes(ctx:ScheduleContext) -> None:
"""search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
  # start by adding uops that always realize
  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: dict[UOp, UOp] = {}
  double_reduces: list[UOp] = []
@@ -280,10 +331,28 @@ def group_realizes(ctx:ScheduleContext) -> None:
  for reduceop in double_reduces:
    top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
    if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
  graph_rewrite(sink, break_sched, ctx)
  return ctx.realizes
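Condensed from the hunks above, the new shape of group_realizes is roughly the following; this is a paraphrase of the diff with the unchanged reduce-grouping body elided, not a standalone snippet:

def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)  # seed the uops that always realize
  ...                                                     # reduce pairing/fusion, unchanged here
  graph_rewrite(sink, break_sched, ctx)                   # split the SINK into LOAD/STORE
  return ctx.realizes                                     # hand the realize map back to the caller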
# **** Schedule creation and BFS toposort
# break the SINK into stores
# ** this is schedule level const folding
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
  if b not in ctx.realizes: return x # collapse BUFFER
  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
break_sched = PatternMatcher([
  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
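A toy model of the split break_sched performs; every name here is hypothetical. A plain view of a buffer becomes a LOAD, an op stored into a realized buffer becomes a STORE that consumers LOAD back, and anything unrealized fuses in place.

from dataclasses import dataclass

@dataclass(frozen=True)
class ToyView:
  buf: str
  compute: str | None = None  # None: plain view of a buffer; else: the op stored into it

def break_node(v: ToyView, realizes: dict[str, str]) -> str:
  if v.compute is None: return f"LOAD({v.buf})"       # mirrors load_realized
  if v.buf not in realizes: return v.compute          # not realized: fuse the op
  realizes[v.buf] = f"STORE({v.buf}, {v.compute})"    # mirrors store_or_fuse
  return f"LOAD({v.buf})"                             # consumers read the result back

realizes = {"b0": ""}
assert break_node(ToyView("b0", "mul(a, a)"), realizes) == "LOAD(b0)"
assert realizes["b0"] == "STORE(b0, mul(a, a))"
assert break_node(ToyView("b1", "add(a, a)"), realizes) == "add(a, a)"  # fused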
# **** schedule simplifier
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
  if not all_int(x.shape): return None
@@ -338,80 +407,6 @@ sym = symbolic_simple+PatternMatcher([
    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
# ** this decides which ops get realized
class UPatScheduled(UPat):
  def __init__(self, *args, **kwargs):
    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
  st = unwrap(view.st)
  # fold simple pads
  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
  # early realize before expand
  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
  # otherwise safety check pads
  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
  del ctx.realizes[b]
  return x.view(unwrap(view.st))
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
do_realize = PatternMatcher([
  # always realize SINK parents
  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
  # realize before expand or unsafe pad ops
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
  # don't realize image to image casts
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
   fold_img_cast),
  # realize before COPY or BUFFER_VIEW
  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])
# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
  if b not in ctx.realizes: return x # collapse BUFFER
  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
break_sched = PatternMatcher([
  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
# **** Schedule context builder
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
  ctx.allbufs[buf_uop] = view
  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
  for x in op.base.src:
    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
# **** movement ops
remove_movement_ops = merge_views+PatternMatcher([
  # NOTE: movement ops are always applied to base
  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
@@ -420,6 +415,8 @@ remove_movement_ops = merge_views+PatternMatcher([
   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
])
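The last rule above folds a view to const 0 whenever any mask interval is empty; a standalone check of that condition with an invented mask:

# an empty (start == end) interval on any axis means the view selects no valid
# data, so every element is padding and the whole view collapses to const 0
mask = ((0, 4), (2, 2))  # second axis keeps nothing: 2 - 2 == 0
assert any(hi - lo == 0 for lo, hi in mask)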
# **** schedule creation and toposort
@track_rewrites(named=True)
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
@@ -438,11 +435,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
  for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
  # add BUFFER uops
  sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
  # add realizes
  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
  # group realizes into kernels
  group_realizes(ctx)
  graph_rewrite(sink, break_sched, ctx)
  # get realizes
  realize_map = group_realizes(sink, ctx)
  # TODO: this should be the break between the "grouper" and the "linearizer"
  # here, there should just be one SINK UOp with BUFFER/KERNEL/COPY/ASSIGN (ASSIGN is the parent if you want the buffer post-assign)
@@ -451,7 +445,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
  # create schedule items + map buffers to realized tensors
  prescheduled: list[ScheduleItem] = []
  var_vals: dict[Variable, int] = {}
  for buf_uop,store in ctx.realizes.items():
  for buf_uop,store in realize_map.items():
    assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
    prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
  # can only schedule once
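Putting the caller-side hunks together, create_schedule_with_vars now reads roughly as below; this is a condensed paraphrase of the diff, with unrelated steps elided:

def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
  ...
  sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
  realize_map = group_realizes(sink, ctx)  # one call: do_realize, grouping, and break_sched
  for buf_uop,store in realize_map.items():
    assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
    prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
  ...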