flatten fusion upats [pr] (#7732)

This commit is contained in:
qazal
2024-11-16 15:26:19 +02:00
committed by GitHub
parent ec8c5598f6
commit f3f95ab9d9

View File

@@ -37,7 +37,8 @@ class ScheduleItem:
# **** small wrapper for LazyBuffer -> UOp
def UPatLoadStore(to_store=UPat()):
  """Build a pattern for a LOAD whose third source is a STORE into the same buffer.

  The buffer pattern is captured as "b", the STORE as "store" and the outer LOAD as "load".
  `to_store` is the pattern used for the value being stored.
  """
  buf = UPat.var("b")
  store_pat = UPat.store(buf, UPat(), to_store, name="store")
  return UPat.load(buf, UPat(), store_pat, name="load")
def UPatSrc(*args, **kwargs):
  """Build a pattern for a LOAD whose third source is a STORE into the same buffer.

  `args`/`kwargs` describe the stored value; its pattern name is always forced to "to_store",
  overriding any `name` passed by the caller. The buffer is captured as "b" and the outer LOAD as "base".
  """
  buf = UPat.var("b")
  value_kwargs = {**kwargs, "name": "to_store"}
  value_pat = UPat(*args, **value_kwargs)
  store_pat = UPat.store(buf, UPat(), value_pat)
  return UPat.load(buf, UPat(), store_pat, name="base")
@functools.lru_cache(None)
def is_scheduled(u:UOp):
  """Return True when `u` is a LOAD UOp with exactly three sources (results are memoized per `u`)."""
  if u.op is not Ops.LOAD: return False
  return len(u.src) == 3
@@ -198,7 +199,7 @@ to_si = PatternMatcher([
# ** fusion
# Rules whose callbacks hand back one of the captured sub-patterns unchanged.
# NOTE(review): both a UPatLoadStore rule and a UPatSrc rule are listed here; this looks like the
# old and new lines of a diff shown together — confirm which one the real file keeps.
lazy = PatternMatcher([
# load/store pair built by UPatLoadStore: the callback returns the stored value captured as "v"
(UPatLoadStore(UPat.var("v")), lambda ctx,v,**kwargs: v),
# load/store pair built by UPatSrc: the callback returns the stored value captured as "to_store"
(UPatSrc(), lambda ctx,to_store,**kwargs: to_store),
# CONTIGUOUS with a single source: the callback returns that source
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
])
@@ -332,34 +333,34 @@ def group_realizes(children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp,
# **** Schedule creation and BFS toposort
def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
  """Record `store` under buffer `b` in `ctx` and return a two-source LOAD of `b`.

  The returned LOAD keeps `load`'s dtype and uses `load.st_arg.to_uop()` as its second source.
  """
  ctx[b] = store
  st_uop = load.st_arg.to_uop()
  return UOp(Ops.LOAD, load.dtype, (b, st_uop))
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
  """Record a fresh STORE of `to_store` into buffer `b` in `ctx` and return a two-source LOAD of `b`.

  The STORE's ShapeTracker is rebuilt from the shape of `base`'s ShapeTracker; the returned LOAD
  keeps `base`'s dtype and uses `base`'s own ShapeTracker as its second source.
  """
  st = unwrap(base.st)
  store_st_uop = ShapeTracker.from_shape(st.shape).to_uop()
  ctx[b] = UOp.store(b, store_st_uop, to_store)
  return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
# Decide whether the source under a VIEW must be realized: returns None to leave the match alone,
# otherwise the result of realize(...) re-viewed through `view`'s ShapeTracker.
# NOTE(review): two `def` headers and paired return lines appear below (a **kwargs form and an explicit
# b/to_store/base form); this looks like the old and new lines of a diff shown together — confirm.
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, **kwargs) -> Optional[UOp]:
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
base_shape = unwrap(base.st).shape
st = unwrap(view.st)
# fold simple pads
# (single masked view, all-int base shape, and the base has at least as many elements as the masked region)
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
return None if can_pad(base) else realize(ctx, **kwargs).view(st)
return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
# early realize before expand
# (the view has more elements than the base)
if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, **kwargs).view(st)
if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
# otherwise safety check pads
# (no realize when no view carries a mask or when can_pad(base) holds)
return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, **kwargs).view(st)
return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
# Rules that pick which matched sources get realized; the callbacks receive `ctx`, the dict that
# realize() writes buffer -> STORE entries into.
# NOTE(review): every rule appears twice (a UPatLoadStore form and a UPatSrc form); this looks like the
# old and new lines of a diff shown together — confirm which one the real file keeps.
do_realize = PatternMatcher([
# always realize meta ops
(UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta))), realize),
(UPatSrc((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta)), realize),
# don't realize image to image casts
# (only when ctx already holds a non-meta STORE for x's buffer and that buffer's dtype is an ImageDType)
(UPatLoadStore(UPat(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float)).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
(UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None),
# realize before expand or unsafe pad ops
(UPatLoadStore(UPat.var("base")).view(name="view"), realize_view),
(UPatSrc().view(name="view"), realize_view),
# realize before COPY or BUFFER_VIEW
# (the second source may be matched directly or under a VIEW; with a VIEW the realized result is re-viewed through view.st)
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="view"))), name="root"),
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatSrc(), UPatSrc().view(name="view"))), name="root"),
lambda ctx,root,u,view=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st))),),
])
# Single-rule matcher: calls realize() for a matched source only when its buffer `b` is already a key
# in ctx; otherwise the callback returns None.
# NOTE(review): the two assignments below look like the old and new lines of a diff shown together — confirm.
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
break_sched = PatternMatcher([(UPatSrc(), lambda ctx,b,to_store,base: realize(ctx, b, to_store, base) if b in ctx else None),])
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: