mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
flatten fusion upats [pr] (#7732)
This commit is contained in:
@@ -37,7 +37,8 @@ class ScheduleItem:
|
||||
|
||||
# **** small wrapper for LazyBuffer -> UOp
|
||||
|
||||
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
|
||||
def UPatSrc(*args, **kwargs):
|
||||
return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat(*args, **{**kwargs, "name":"to_store"})), name="base")
|
||||
@functools.lru_cache(None)
|
||||
def is_scheduled(u:UOp): return u.op is Ops.LOAD and len(u.src) == 3
|
||||
|
||||
@@ -198,7 +199,7 @@ to_si = PatternMatcher([
|
||||
# ** fusion
|
||||
|
||||
lazy = PatternMatcher([
|
||||
(UPatLoadStore(UPat.var("v")), lambda ctx,v,**kwargs: v),
|
||||
(UPatSrc(), lambda ctx,to_store,**kwargs: to_store),
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
|
||||
])
|
||||
|
||||
@@ -332,34 +333,34 @@ def group_realizes(children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp,
|
||||
|
||||
# **** Schedule creation and BFS toposort
|
||||
|
||||
def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
|
||||
ctx[b] = store
|
||||
return UOp(Ops.LOAD, load.dtype, (b, load.st_arg.to_uop()))
|
||||
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
|
||||
return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
|
||||
|
||||
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, **kwargs) -> Optional[UOp]:
|
||||
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
|
||||
base_shape = unwrap(base.st).shape
|
||||
st = unwrap(view.st)
|
||||
# fold simple pads
|
||||
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
|
||||
return None if can_pad(base) else realize(ctx, **kwargs).view(st)
|
||||
return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
|
||||
# early realize before expand
|
||||
if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, **kwargs).view(st)
|
||||
if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
|
||||
# otherwise safety check pads
|
||||
return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, **kwargs).view(st)
|
||||
return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
|
||||
|
||||
do_realize = PatternMatcher([
|
||||
# always realize meta ops
|
||||
(UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta))), realize),
|
||||
(UPatSrc((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta)), realize),
|
||||
# don't realize image to image casts
|
||||
(UPatLoadStore(UPat(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float)).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
|
||||
(UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
|
||||
if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None),
|
||||
# realize before expand or unsafe pad ops
|
||||
(UPatLoadStore(UPat.var("base")).view(name="view"), realize_view),
|
||||
(UPatSrc().view(name="view"), realize_view),
|
||||
# realize before COPY or BUFFER_VIEW
|
||||
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="view"))), name="root"),
|
||||
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatSrc(), UPatSrc().view(name="view"))), name="root"),
|
||||
lambda ctx,root,u,view=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st))),),
|
||||
])
|
||||
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
|
||||
break_sched = PatternMatcher([(UPatSrc(), lambda ctx,b,to_store,base: realize(ctx, b, to_store, base) if b in ctx else None),])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
|
||||
Reference in New Issue
Block a user