Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-24 22:38:16 -05:00
defer realize folding to kernel splitting [pr] (#7849)
* defer realize folding to schedule breaking [pr]
* this is init
* p2
* need to lookup edges
* refactor image cast folding [pr]
* Ops.LOAD diff
* image works
* refactor can_pad
* fix fold_img_cast
@@ -441,7 +441,7 @@ class Kernel:
     check(not self.vars, "does not work with symbolic shape")
     check(axis < self.first_upcast, "cannot pad upcasted")
     # ok to pad SUM if all parent ALU ops have f(0) = 0
-    if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
+    if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, set()), f"cannot pad {r}")
     padded = False
     for i,st in enumerate(self.sts):
       if (s:=st.shape[axis]) == 1: continue # reduced
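The f(0) = 0 condition in the comment above is what can_pad is guarding: a padded view fills the out-of-range region with zeros, so a SUM reduce is unchanged only if every elementwise op between the buffer and the reduce maps zero to zero. A minimal numpy illustration of the safe and unsafe cases (illustration only, not tinygrad code):

```python
# Why only f(0) = 0 ops may sit between a padded buffer and a SUM reduce.
import numpy as np

x = np.array([1.0, 2.0, 3.0])
padded = np.pad(x, (0, 5))  # a padded view fills the new region with zeros

# safe: f(t) = 2*t has f(0) = 0, the padded zeros contribute nothing to the sum
assert (2 * x).sum() == (2 * padded).sum()

# unsafe: f(t) = exp(t) has f(0) = 1, so padding shifts the sum by the pad size
# (this is why ops like EXP2 and RECIP sit in GroupOp.UnsafePad)
assert np.exp(x).sum() != np.exp(padded).sum()
```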
@@ -339,47 +339,51 @@ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UOp]]:

 # **** Schedule creation and BFS toposort

-def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
-  ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
-  return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
+def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None:
+  ctx[b] = to_store
+  return None

-def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
+def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None:
   if to_store.op in {Ops.CONST, Ops.BIND}: return None
   base_shape = unwrap(base.st).shape
   st = unwrap(view.st)
   # fold simple pads
   if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
-    return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
+    return None if can_pad(base, ctx, set()) else realize(ctx, b, to_store, base)
   # early realize before expand
-  if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
+  if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base)
   # otherwise safety check pads
-  return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
+  return None if (all(v.mask is None for v in st.views) or can_pad(base, ctx, set())) else realize(ctx, b, to_store, base)
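The prod(base_shape) >= prod([y-x for x,y in m]) guard in the "fold simple pads" branch compares the element count of the base buffer with the element count the mask keeps valid, so the view is only treated as a foldable pad when it does not also expand the data. A plain-Python sketch of that arithmetic (shapes and mask values are made up for illustration):

```python
from math import prod

base_shape = (4, 4)            # shape of the underlying buffer
mask = ((1, 5), (1, 5))        # per-axis (start, end) of the valid region inside the padded view

valid_elems = prod(end - start for start, end in mask)
is_simple_pad = prod(base_shape) >= valid_elems
# 16 >= 16 -> True: the view only adds padding around the data, so it can be folded
# into the consuming kernel; a mask keeping more elements than the base holds would
# mean the view also expands, forcing a realize instead
print(valid_elems, is_simple_pad)
```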

-def fold_img_cast(ctx, xb:UOp, view:UOp, **kwargs) -> Optional[UOp]:
-  if not isinstance(xb.dtype, ImageDType) or (r:=ctx.get(xb)) is None or r.op is not Ops.STORE or (to_cast:=r.src[2]).op in GroupOp.Meta: return None
+def fold_img_cast(ctx, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
+  if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None
+  del ctx[b]
   return to_cast.view(unwrap(view.st))

 do_realize = PatternMatcher([
   # always realize meta ops
   (UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
-  # don't realize image to image casts
-  (UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, src=(UPat.var("xb"), UPat())),), dtype=dtypes.float).view(name="view"), fold_img_cast),
   # realize before expand or unsafe pad ops
   (UPatSrc().view(name="view"), realize_view),
+  # don't realize image to image casts
+  (UPatSrc(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
   # realize before COPY or BUFFER_VIEW
-  (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view(name="view")),), name="root"),
-   lambda ctx,root,view=None,**kwargs: root.replace(src=(realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st),)),),
+  (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view()),)), realize),
 ])
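do_realize is a PatternMatcher: the patterns are tried in order against each UOp, a callback that returns a UOp rewrites the matched node, and a callback that returns None leaves it untouched. The image-cast rule now matches a VIEW of the source buffer rather than a LOAD and sits after the realize_view rule, consistent with realize no longer substituting LOADs during this pass. A toy first-match-wins dispatcher in the same spirit (a hypothetical mini-matcher, not tinygrad's PatternMatcher):

```python
from typing import Callable, Optional

# (predicate, callback) pairs tried in order; a callback returning None means "no rewrite"
Rule = tuple[Callable[[str], bool], Callable[[str], Optional[str]]]

def rewrite(node: str, rules: list[Rule]) -> str:
    for matches, callback in rules:
        if matches(node) and (replacement := callback(node)) is not None:
            return replacement   # first rule that produces a replacement wins
    return node                  # no rule fired: the node stays in the graph unchanged

rules: list[Rule] = [
    (lambda n: n.startswith("CAST(IMAGE"), lambda n: None),          # matches but declines to rewrite
    (lambda n: n.startswith("CAST"),       lambda n: f"FOLDED[{n}]"),
]
print(rewrite("CAST(IMAGE_BUF)", rules))  # FOLDED[CAST(IMAGE_BUF)]: falls through to the next rule
```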

 def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
   if isinstance((val:=to_store.arg), UOp): ctx.var_vals.update([val.unbind()])
   return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape)

+def append_kernel(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
+  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
+  return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
+
 break_sched = PatternMatcher([
   # consts are always fused and generated
   (UPatSrc({Ops.CONST, Ops.BIND}), generate_valid),
   # everything else is a VIEW of BUFFER that either realizes or fuses
-  (UPatSrc(), lambda ctx,b,to_store,base: realize(ctx.realizes, b, to_store, base) if b in ctx.realizes else None),
+  (UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else None),
 ])

 @track_rewrites(named=True)
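This is the heart of the change: during do_realize, realize now only records ctx[b] = to_store, and the STORE/LOAD pair is built later by append_kernel when break_sched splits the graph into kernels; it is also what lets fold_img_cast cancel a pending realize with a plain del ctx[b]. A stripped-down sketch of the two phases, with strings standing in for UOps (hypothetical helpers, not tinygrad types):

```python
# phase 1 (do_realize): only tag buffers that must be realized, leave the graph alone
realizes: dict[str, str] = {}

def realize(buf: str, to_store: str) -> None:
    realizes[buf] = to_store                 # defer: just remember what will be stored

def cancel(buf: str) -> None:
    del realizes[buf]                        # a later fold (like fold_img_cast) can still undo the tag

# phase 2 (break_sched): turn each surviving tag into a STORE and hand back a LOAD for fusion
def append_kernel(buf: str) -> str:
    realizes[buf] = f"STORE({buf}, {realizes[buf]})"
    return f"LOAD({buf})"

realize("b0", "ADD(x, y)")
print(append_kernel("b0"))                   # LOAD(b0)
print(realizes)                              # {'b0': 'STORE(b0, ADD(x, y))'}
```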
@@ -390,13 +394,11 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
   ctx = ScheduleContext()
   cache: Dict[LazyBuffer, UOp] = {}
   buffers: Dict[UOp, Buffer] = {}
-  big_graph = UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))
   # get realizes
-  graph_rewrite(big_graph, do_realize, ctx.realizes)
+  big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), do_realize, ctx.realizes)
   # group realizes into kernels
   store_groups = group_realizes(ctx, ctx.realizes)
   # split realizes into small graphs
   graph_rewrite(big_graph, break_sched, ctx)
-  # preschedule all realizes
+  # preschedule realize groups
   prescheduled: List[ScheduleItem] = []
   for store_uops in store_groups:
     ast, ast_ctx = full_ast_rewrite(UOp.sink(*(ctx.realizes[u] for u in store_uops)), ctx)
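The pipeline in this hunk: one graph_rewrite pass with do_realize fills ctx.realizes and returns the rewritten sink (which is why its return value is now kept as big_graph), group_realizes partitions the tagged buffers into kernels, break_sched splits the big graph along those realizes, and each group is sunk and lowered through full_ast_rewrite. A toy fixed-point rewrite loop in the spirit of graph_rewrite, showing why the return value is the graph later passes should consume (illustration only, not tinygrad's implementation):

```python
from typing import Callable, Optional

def graph_rewrite(node: str, rule: Callable[[str], Optional[str]]) -> str:
    # apply the rule until it stops matching, then hand back the rewritten node;
    # callers that drop the return value only see the rule's side effects
    while (rewritten := rule(node)) is not None:
        node = rewritten
    return node

realizes: dict[str, str] = {}

def tag_contiguous(n: str) -> Optional[str]:
    if n.startswith("CONTIGUOUS(") and n.endswith(")"):
        inner = n[len("CONTIGUOUS("):-1]
        realizes[inner] = inner              # side effect: remember that this value must realize
        return inner                         # rewrite: strip the CONTIGUOUS wrapper
    return None

big_graph = graph_rewrite("CONTIGUOUS(ADD(x, y))", tag_contiguous)
print(big_graph, realizes)                   # ADD(x, y) {'ADD(x, y)': 'ADD(x, y)'}
```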
@@ -185,7 +185,11 @@ class GroupOp:
 # https://en.wikipedia.org/wiki/Identity_element
 def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)

-def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
+def can_pad(u:UOp, edges:Dict[UOp, UOp], visisted:Set[UOp]) -> bool:
+  if u.op in GroupOp.UnsafePad: return False
+  if (len(u.src) == 2 and u.src[0] in edges) or u in visisted: return True
+  visisted.add(u)
+  return all(can_pad(x.base, edges, visisted) for x in u.src)

 END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
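The rewritten can_pad walks the graph itself instead of scanning u.sparents: it fails on any GroupOp.UnsafePad op, returns early at nodes it has already visited and at two-source nodes whose first source is tagged in the realizes dict (a boundary it does not need to look past, since that subtree realizes separately), and otherwise recurses through the sources' bases. A standalone sketch of the same traversal over a toy DAG (hypothetical Node class, not tinygrad's UOp):

```python
from dataclasses import dataclass

UNSAFE_PAD = {"RECIP", "EXP2"}   # stand-in for GroupOp.UnsafePad

@dataclass(eq=False)             # identity-based hashing so nodes can live in sets/dicts
class Node:
    op: str
    src: tuple["Node", ...] = ()
    @property
    def base(self) -> "Node": return self   # tinygrad's x.base would skip past VIEW wrappers

def can_pad(u: Node, edges: dict[Node, Node], visited: set[Node]) -> bool:
    if u.op in UNSAFE_PAD: return False                                       # f(0) != 0 somewhere
    if (len(u.src) == 2 and u.src[0] in edges) or u in visited: return True   # realize edge / already seen
    visited.add(u)
    return all(can_pad(x.base, edges, visited) for x in u.src)

buf = Node("BUFFER")
safe = Node("MUL", (Node("LOAD", (buf,)), Node("CONST")))
unsafe = Node("RECIP", (Node("LOAD", (buf,)),))
print(can_pad(safe, {}, set()), can_pad(unsafe, {}, set()))   # True False
```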