From 5b2c03e86553021e17543c3530343df92089d506 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sat, 23 Nov 2024 01:29:14 -0500
Subject: [PATCH] defer realize folding to kernel splitting [pr] (#7849)

* defer realize folding to schedule breaking [pr]

* this is init

* p2

* need to lookup edges

* refactor image cast folding [pr]

* Ops.LOAD diff

* image works

* refactor can_pad

* fix fold_img_cast
---
 tinygrad/codegen/kernel.py  |  2 +-
 tinygrad/engine/schedule.py | 40 +++++++++++++++++++------------------
 tinygrad/ops.py             |  6 +++++-
 3 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 3d380308d4..d23a775860 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -441,7 +441,7 @@ class Kernel:
       check(not self.vars, "does not work with symbolic shape")
       check(axis < self.first_upcast, "cannot pad upcasted")
       # ok to pad SUM if all parent ALU ops have f(0) = 0
-      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
+      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, set()), f"cannot pad {r}")
       padded = False
       for i,st in enumerate(self.sts):
         if (s:=st.shape[axis]) == 1: continue  # reduced
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ba06dca7b7..14fd5a594f 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -339,47 +339,51 @@ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UO
 
 # **** Schedule creation and BFS toposort
 
-def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
-  ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
-  return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
+def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None:
+  ctx[b] = to_store
+  return None
 
-def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
+def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None:
   if to_store.op in {Ops.CONST, Ops.BIND}: return None
   base_shape = unwrap(base.st).shape
   st = unwrap(view.st)
   # fold simple pads
   if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
-    return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
+    return None if can_pad(base, ctx, set()) else realize(ctx, b, to_store, base)
   # early realize before expand
-  if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
+  if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base)
   # otherwise safety check pads
-  return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
+  return None if (all(v.mask is None for v in st.views) or can_pad(base, ctx, set())) else realize(ctx, b, to_store, base)
 
-def fold_img_cast(ctx, xb:UOp, view:UOp, **kwargs) -> Optional[UOp]:
-  if not isinstance(xb.dtype, ImageDType) or (r:=ctx.get(xb)) is None or r.op is not Ops.STORE or (to_cast:=r.src[2]).op in GroupOp.Meta: return None
+def fold_img_cast(ctx, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
+  if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None
+  del ctx[b]
   return to_cast.view(unwrap(view.st))
 
 do_realize = PatternMatcher([
   # always realize meta ops
   (UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
-  # don't realize image to image casts
-  (UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, src=(UPat.var("xb"), UPat())),), dtype=dtypes.float).view(name="view"), fold_img_cast),
   # realize before expand or unsafe pad ops
   (UPatSrc().view(name="view"), realize_view),
+  # don't realize image to image casts
+  (UPatSrc(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
   # realize before COPY or BUFFER_VIEW
-  (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view(name="view")),), name="root"),
-   lambda ctx,root,view=None,**kwargs: root.replace(src=(realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st),)),),
+  (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view()),)), realize),
 ])
 
 def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
   if isinstance((val:=to_store.arg), UOp): ctx.var_vals.update([val.unbind()])
   return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape)
 
+def append_kernel(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
+  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
+  return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
+
 break_sched = PatternMatcher([
   # consts are always fused and generated
   (UPatSrc({Ops.CONST, Ops.BIND}), generate_valid),
   # everything else is a VIEW of BUFFER that either realizes or fuses
-  (UPatSrc(), lambda ctx,b,to_store,base: realize(ctx.realizes, b, to_store, base) if b in ctx.realizes else None),
+  (UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else None),
 ])
 
 @track_rewrites(named=True)
@@ -390,13 +394,11 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
   ctx = ScheduleContext()
   cache: Dict[LazyBuffer, UOp] = {}
   buffers: Dict[UOp, Buffer] = {}
-  big_graph = UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))
-  # get realizes
-  graph_rewrite(big_graph, do_realize, ctx.realizes)
+  big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), do_realize, ctx.realizes)
+  # group realizes into kernels
   store_groups = group_realizes(ctx, ctx.realizes)
-  # split realizes into small graphs
   graph_rewrite(big_graph, break_sched, ctx)
-  # preschedule all realizes
+  # preschedule realize groups
   prescheduled: List[ScheduleItem] = []
   for store_uops in store_groups:
     ast, ast_ctx = full_ast_rewrite(UOp.sink(*(ctx.realizes[u] for u in store_uops)), ctx)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 5e06220848..f06e8355bb 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -185,7 +185,11 @@ class GroupOp:
 
 # https://en.wikipedia.org/wiki/Identity_element
 def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
 
-def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
+def can_pad(u:UOp, edges:Dict[UOp, UOp], visited:Set[UOp]) -> bool:
+  if u.op in GroupOp.UnsafePad: return False
+  if (len(u.src) == 2 and u.src[0] in edges) or u in visited: return True
+  visited.add(u)
+  return all(can_pad(x.base, edges, visited) for x in u.src)
 
 END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
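
As a quick illustration of the refactored can_pad contract (a minimal standalone
sketch, not part of the patch): the hypothetical Node class below stands in for
UOp, UNSAFE_PAD for GroupOp.UnsafePad, and edges for the realizes dict that
do_realize fills in. The walk rejects any op that is unsafe to zero-pad, stops
early at sources that are already realized buffers, and memoizes through visited.

from typing import Dict, Set

class Node:
  """Hypothetical stand-in for UOp: an op name plus source nodes."""
  def __init__(self, op:str, *src:"Node"): self.op, self.src = op, src
  @property
  def base(self) -> "Node": return self  # UOp.base strips views; identity here

UNSAFE_PAD = {"RECIP", "IDIV"}  # stand-in for GroupOp.UnsafePad: ops where f(0) != 0

def can_pad(u:Node, edges:Dict[Node, Node], visited:Set[Node]) -> bool:
  if u.op in UNSAFE_PAD: return False
  # a 2-src node whose first source is a realized buffer is a kernel edge: stop here
  if (len(u.src) == 2 and u.src[0] in edges) or u in visited: return True
  visited.add(u)
  return all(can_pad(x.base, edges, visited) for x in u.src)

a, b = Node("LOAD"), Node("LOAD")
assert can_pad(Node("ADD", a, b), {}, set())                     # ADD(0, 0) == 0: padding is sound
assert not can_pad(Node("RECIP", Node("ADD", a, b)), {}, set())  # RECIP(0) != 0: unsafe to pad

With empty edges and a fresh visited set, as in the kernel.py call site
can_pad(r, {}, set()), the walk simply visits every ancestor of the reduce,
matching the old sparents-based behavior.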