diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 1083240ea7..f5bc8612ee 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -203,7 +203,57 @@ if CAPTURE_PROCESS_REPLAY:
   def save_process_replay() -> None:
     for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
 
-# **** Schedule grouping
+# **** UOp realization
+
+class UPatScheduled(UPat):
+  def __init__(self, *args, **kwargs):
+    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
+
+def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
+
+def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
+  st = unwrap(view.st)
+  # fold simple pads
+  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
+    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
+  # early realize before expand
+  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
+  # otherwise safety check pads
+  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
+
+def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
+  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
+  del ctx.realizes[b]
+  return x.view(unwrap(view.st))
+
+def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
+  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
+  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
+  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
+
+do_realize = PatternMatcher([
+  # always realize SINK parents
+  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
+  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
+  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
+  # realize before expand or unsafe pad ops
+  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
+  # don't realize image to image casts
+  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
+   fold_img_cast),
+  # realize before COPY or BUFFER_VIEW
+  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
+])
+
+def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
+  ctx.allbufs[buf_uop] = view
+  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
+  for x in op.base.src:
+    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
+create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
 
 def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
 def uval(u:UOp) -> UOp:
@@ -228,8 +278,9 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
     if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
     recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
 
-def group_realizes(ctx:ScheduleContext) -> None:
-  """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
+def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
+  # start by adding uops that always realize
+  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
   reduce_for_op: dict[UOp, UOp] = {}
   double_reduces: list[UOp] = []
@@ -280,10 +331,28 @@ def group_realizes(ctx:ScheduleContext) -> None:
   for reduceop in double_reduces:
     top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
     if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
+  graph_rewrite(sink, break_sched, ctx)
+  return ctx.realizes
 
-# **** Schedule creation and BFS toposort
+# break the SINK into stores
 
-# ** this is schedule level const folding
+def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
+  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
+  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
+
+def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
+  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
+  if b not in ctx.realizes: return x # collapse BUFFER
+  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
+  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
+
+break_sched = PatternMatcher([
+  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
+  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
+  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
+])
+
+# **** schedule simplifier
 
 def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
   if not all_int(x.shape): return None
@@ -338,80 +407,6 @@ sym = symbolic_simple+PatternMatcher([
     if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
 ])
 
-# ** this decides which ops get realized
-
-class UPatScheduled(UPat):
-  def __init__(self, *args, **kwargs):
-    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
-
-def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
-
-def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
-  st = unwrap(view.st)
-  # fold simple pads
-  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
-    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
-  # early realize before expand
-  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
-  # otherwise safety check pads
-  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
-
-def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
-  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
-  del ctx.realizes[b]
-  return x.view(unwrap(view.st))
-
-def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
-  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
-  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
-  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
-
-do_realize = PatternMatcher([
-  # always realize SINK parents
-  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
-  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
-  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
-  # realize before expand or unsafe pad ops
-  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
-  # don't realize image to image casts
-  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
-   fold_img_cast),
-  # realize before COPY or BUFFER_VIEW
-  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
-  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
-  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
-  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
-])
-
-# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
-
-def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
-  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
-  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
-
-def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
-  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
-  if b not in ctx.realizes: return x # collapse BUFFER
-  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
-  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
-
-break_sched = PatternMatcher([
-  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
-  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
-  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
-])
-
-# **** Schedule context builder
-
-def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
-  ctx.allbufs[buf_uop] = view
-  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
-  for x in op.base.src:
-    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
-create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
-
-# **** movement ops
-
 remove_movement_ops = merge_views+PatternMatcher([
   # NOTE: movement ops are always applied to base
   (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
@@ -420,6 +415,8 @@ remove_movement_ops = merge_views+PatternMatcher([
     lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
 ])
 
+# **** schedule creation and toposort
+
 @track_rewrites(named=True)
 def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
   tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
@@ -438,11 +435,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
   # add BUFFER uops
   sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
-  # add realizes
-  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
-  # group realizes into kernels
-  group_realizes(ctx)
-  graph_rewrite(sink, break_sched, ctx)
+  # get realizes
+  realize_map = group_realizes(sink, ctx)
   # TODO: this should be the break between the "grouper" and the "linearizer"
   # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
 
@@ -451,7 +445,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   # create schedule items + map buffers to realized tensors
   prescheduled: list[ScheduleItem] = []
   var_vals: dict[Variable, int] = {}
-  for buf_uop,store in ctx.realizes.items():
+  for buf_uop,store in realize_map.items():
     assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
     prescheduled.append(schedule_uop(store.sink(), ctx, var_vals)) # can only schedule once
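
The mechanism this patch reorganizes: do_realize, create_ctx, and break_sched are PatternMatcher tables that graph_rewrite walks over the UOp graph, calling the paired function whenever a pattern matches and substituting any non-None return value. The sketch below is a minimal stand-in for that loop, not tinygrad's real API: ToyUOp, toy_graph_rewrite, and toy_realize_sink_parents are illustrative names, and dispatch is on a bare op string rather than full UPat trees with named captures.

from __future__ import annotations
from dataclasses import dataclass
from typing import Callable

@dataclass(frozen=True)
class ToyUOp:
  op: str                          # e.g. "BUFFER", "VIEW", "ADD", "SINK"
  src: tuple["ToyUOp", ...] = ()

def toy_graph_rewrite(root:ToyUOp, matcher:dict[str, Callable], ctx) -> ToyUOp:
  # bottom-up: rewrite children first, then give the matcher a shot at the parent
  cache: dict[ToyUOp, ToyUOp] = {}
  def rewrite(u:ToyUOp) -> ToyUOp:
    if u in cache: return cache[u]
    ret = ToyUOp(u.op, tuple(rewrite(s) for s in u.src))
    # a rule returns a replacement UOp, or None for "no change"
    if (fn:=matcher.get(ret.op)) is not None and (new:=fn(ctx, ret)) is not None: ret = new
    cache[u] = ret
    return ret
  return rewrite(root)

# a rule in the spirit of do_realize's SINK case: record the SINK's parents
# in the context instead of changing the graph
def toy_realize_sink_parents(ctx:dict, sink:ToyUOp):
  ctx.update((x, x) for x in sink.src)
  return None

realizes: dict[ToyUOp, ToyUOp] = {}
g = ToyUOp("SINK", (ToyUOp("ADD", (ToyUOp("BUFFER"), ToyUOp("BUFFER"))),))
toy_graph_rewrite(g, {"SINK": toy_realize_sink_parents}, realizes)
assert len(realizes) == 1  # the single ADD parent of the SINK is marked to realize

The real matcher binds named sub-patterns (the name="b" / name="to_store" captures in UPatScheduled) to keyword arguments of the rule, which is why realize and friends take b and to_store directly; the driving loop has the same shape as above.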
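The core decision in realize_before_view can be read off its branches: a masked (padded) view is folded only when can_pad says it is safe, and a view that expands its source forces an early realize so downstream kernels read a materialized buffer instead of recomputing the source once per output element. A toy version of the expand check, using plain shape tuples instead of ShapeTrackers (toy_realize_before_expand is a hypothetical helper, not tinygrad code):

from math import prod

def toy_realize_before_expand(src_shape:tuple[int, ...], view_shape:tuple[int, ...]) -> bool:
  # mirrors resolve(prod(src.shape) < prod(st.shape)) for concrete int shapes:
  # if the view is strictly bigger than its source, realize the source first
  return prod(src_shape) < prod(view_shape)

assert toy_realize_before_expand((4, 1), (4, 32))       # broadcast along dim 1: realize
assert not toy_realize_before_expand((4, 32), (4, 32))  # same number of elements: safe to fuse

This is also what the DONT_REALIZE_EXPAND escape hatch in the patch toggles: with it set, the expand branch is skipped and fusion falls through to the pad-safety check.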