defer realize folding to kernel splitting [pr] (#7849)

* defer realize folding to schedule breaking [pr]

* this is init

* p2

* need to lookup edges

* refactor image cast folding [pr]

* Ops.LOAD diff

* image works

* refactor can_pad

* fix fold_img_cast
This commit is contained in:
qazal
2024-11-23 01:29:14 -05:00
committed by GitHub
parent 144e9f00df
commit 5b2c03e865
3 changed files with 27 additions and 21 deletions

View File

@@ -441,7 +441,7 @@ class Kernel:
check(not self.vars, "does not work with symbolic shape")
check(axis < self.first_upcast, "cannot pad upcasted")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, set()), f"cannot pad {r}")
padded = False
for i,st in enumerate(self.sts):
if (s:=st.shape[axis]) == 1: continue # reduced

View File

@@ -339,47 +339,51 @@ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UO
# **** Schedule creation and BFS toposort
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None:
ctx[b] = to_store
return None
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None:
if to_store.op in {Ops.CONST, Ops.BIND}: return None
base_shape = unwrap(base.st).shape
st = unwrap(view.st)
# fold simple pads
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
return None if can_pad(base, ctx, set()) else realize(ctx, b, to_store, base)
# early realize before expand
if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base)
# otherwise safety check pads
return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
return None if (all(v.mask is None for v in st.views) or can_pad(base, ctx, set())) else realize(ctx, b, to_store, base)
def fold_img_cast(ctx, xb:UOp, view:UOp, **kwargs) -> Optional[UOp]:
if not isinstance(xb.dtype, ImageDType) or (r:=ctx.get(xb)) is None or r.op is not Ops.STORE or (to_cast:=r.src[2]).op in GroupOp.Meta: return None
def fold_img_cast(ctx, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None
del ctx[b]
return to_cast.view(unwrap(view.st))
do_realize = PatternMatcher([
# always realize meta ops
(UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
# don't realize image to image casts
(UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, src=(UPat.var("xb"), UPat())),), dtype=dtypes.float).view(name="view"), fold_img_cast),
# realize before expand or unsafe pad ops
(UPatSrc().view(name="view"), realize_view),
# don't realize image to image casts
(UPatSrc(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
# realize before COPY or BUFFER_VIEW
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view(name="view")),), name="root"),
lambda ctx,root,view=None,**kwargs: root.replace(src=(realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st),)),),
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view()),)), realize),
])
def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
if isinstance((val:=to_store.arg), UOp): ctx.var_vals.update([val.unbind()])
return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape)
def append_kernel(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
break_sched = PatternMatcher([
# consts are always fused and generated
(UPatSrc({Ops.CONST, Ops.BIND}), generate_valid),
# everything else is a VIEW of BUFFER that either realizes or fuses
(UPatSrc(), lambda ctx,b,to_store,base: realize(ctx.realizes, b, to_store, base) if b in ctx.realizes else None),
(UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else None),
])
@track_rewrites(named=True)
@@ -390,13 +394,11 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
ctx = ScheduleContext()
cache: Dict[LazyBuffer, UOp] = {}
buffers: Dict[UOp, Buffer] = {}
big_graph = UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))
# get realizes
graph_rewrite(big_graph, do_realize, ctx.realizes)
big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), do_realize, ctx.realizes)
# group realizes into kernels
store_groups = group_realizes(ctx, ctx.realizes)
# split realizes into small graphs
graph_rewrite(big_graph, break_sched, ctx)
# preschedule all realizes
# preschedule realize groups
prescheduled: List[ScheduleItem] = []
for store_uops in store_groups:
ast, ast_ctx = full_ast_rewrite(UOp.sink(*(ctx.realizes[u] for u in store_uops)), ctx)

View File

@@ -185,7 +185,11 @@ class GroupOp:
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
def can_pad(u:UOp, edges:Dict[UOp, UOp], visisted:Set[UOp]) -> bool:
if u.op in GroupOp.UnsafePad: return False
if (len(u.src) == 2 and u.src[0] in edges) or u in visisted: return True
visisted.add(u)
return all(can_pad(x.base, edges, visisted) for x in u.src)
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}