From f0442718988f55430e09cd7984d7a61b2123dc1a Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 29 Nov 2024 01:58:45 -0500 Subject: [PATCH] big graph do_realize cleanup and renames [pr] (#7952) * scheduler do_realize cleanup and renames [pr] * big graph is the better name * more language * append_kernel -> append_realize --- tinygrad/engine/schedule.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index c263793d02..76cb9615b8 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -47,8 +47,6 @@ class ScheduleContext: ops_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) -class UPatSrc(UPat): - def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name": "to_store"})), name="base") @functools.lru_cache(None) def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 @@ -324,6 +322,16 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]: # **** Schedule creation and BFS toposort +# ** ops in the big graph can either be pre-realized or scheduled (fused/realized) + +class UPatRealized(UPat): + def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),)) +class UPatScheduled(UPat): + def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), + UPat(*args, **{**kwargs,"name":"to_store"}))) + +# ** this decides which ops get realized + def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: return ctx.update([(b, to_store)]) def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None: @@ -347,20 +355,22 @@ do_realize = PatternMatcher([ # always realize sinked ops (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.update((x.buf_uop, x) for x in sink.src if is_scheduled(x))), # always realize meta ops - (UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize), + (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize), # realize before expand or unsafe pad ops - (UPatSrc().view(name="view"), realize_view), + (UPatScheduled().view(name="view"), realize_view), # don't realize image to image casts - (UPatSrc(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast), + (UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast), # realize before COPY or BUFFER_VIEW - (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view()),)), realize), + (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize), ]) +# ** this breaks down realized ops into STOREs and rewrites the ops to LOADs + def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp: if isinstance((val:=to_store.arg), UOp): ctx.var_vals.update([val.unbind()]) return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape) -def append_kernel(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp: +def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp: ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), append_op(ctx, b, to_store)) return UOp(Ops.LOAD, base.dtype, (b, st.to_uop())) @@ -370,11 +380,11 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp: break_sched = PatternMatcher([ # consts are always fused and generated - (UPatSrc({Ops.CONST, Ops.BIND}), generate_valid), + (UPatScheduled({Ops.CONST, Ops.BIND}), generate_valid), # everything else is a VIEW of BUFFER that either realizes or fuses - (UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)), + (UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)), # just load realized buffers - (UPat(Ops.BUFFER, name="b").view(name="v"), lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, v.dtype, (b, v.st.to_uop()))), + (UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, base.dtype, (b, base.st.to_uop()))), ]) @track_rewrites(named=True)