big graph do_realize cleanup and renames [pr] (#7952)

* scheduler do_realize cleanup and renames [pr]

* big graph is the better name

* more language

* append_kernel -> append_realize
This commit is contained in:
qazal
2024-11-29 01:58:45 -05:00
committed by GitHub
parent 6e47dc8921
commit f044271898

View File

@@ -47,8 +47,6 @@ class ScheduleContext:
ops_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata
children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
class UPatSrc(UPat):
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name": "to_store"})), name="base")
@functools.lru_cache(None)
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2
@@ -324,6 +322,16 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]:
# **** Schedule creation and BFS toposort
# ** ops in the big graph can either be pre-realized or scheduled (fused/realized)
class UPatRealized(UPat):
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),))
class UPatScheduled(UPat):
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),
UPat(*args, **{**kwargs,"name":"to_store"})))
# ** this decides which ops get realized
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: return ctx.update([(b, to_store)])
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None:
@@ -347,20 +355,22 @@ do_realize = PatternMatcher([
# always realize sinked ops
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.update((x.buf_uop, x) for x in sink.src if is_scheduled(x))),
# always realize meta ops
(UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
# realize before expand or unsafe pad ops
(UPatSrc().view(name="view"), realize_view),
(UPatScheduled().view(name="view"), realize_view),
# don't realize image to image casts
(UPatSrc(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
# realize before COPY or BUFFER_VIEW
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view()),)), realize),
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
])
# ** this breaks down realized ops into STOREs and rewrites the ops to LOADs
def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
if isinstance((val:=to_store.arg), UOp): ctx.var_vals.update([val.unbind()])
return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape)
def append_kernel(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), append_op(ctx, b, to_store))
return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
@@ -370,11 +380,11 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
break_sched = PatternMatcher([
# consts are always fused and generated
(UPatSrc({Ops.CONST, Ops.BIND}), generate_valid),
(UPatScheduled({Ops.CONST, Ops.BIND}), generate_valid),
# everything else is a VIEW of BUFFER that either realizes or fuses
(UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
(UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
# just load realized buffers
(UPat(Ops.BUFFER, name="b").view(name="v"), lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, v.dtype, (b, v.st.to_uop()))),
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, base.dtype, (b, base.st.to_uop()))),
])
@track_rewrites(named=True)