mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
big graph do_realize cleanup and renames [pr] (#7952)
* scheduler do_realize cleanup and renames [pr] * big graph is the better name * more language * append_kernel -> append_realize
This commit is contained in:
@@ -47,8 +47,6 @@ class ScheduleContext:
|
||||
ops_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata
|
||||
children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
class UPatSrc(UPat):
|
||||
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name": "to_store"})), name="base")
|
||||
@functools.lru_cache(None)
|
||||
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2
|
||||
|
||||
@@ -324,6 +322,16 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]:
|
||||
|
||||
# **** Schedule creation and BFS toposort
|
||||
|
||||
# ** ops in the big graph can either be pre-realized or scheduled (fused/realized)
|
||||
|
||||
class UPatRealized(UPat):
|
||||
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),))
|
||||
class UPatScheduled(UPat):
|
||||
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),
|
||||
UPat(*args, **{**kwargs,"name":"to_store"})))
|
||||
|
||||
# ** this decides which ops get realized
|
||||
|
||||
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: return ctx.update([(b, to_store)])
|
||||
|
||||
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None:
|
||||
@@ -347,20 +355,22 @@ do_realize = PatternMatcher([
|
||||
# always realize sinked ops
|
||||
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.update((x.buf_uop, x) for x in sink.src if is_scheduled(x))),
|
||||
# always realize meta ops
|
||||
(UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
|
||||
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
(UPatSrc().view(name="view"), realize_view),
|
||||
(UPatScheduled().view(name="view"), realize_view),
|
||||
# don't realize image to image casts
|
||||
(UPatSrc(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
|
||||
(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
|
||||
# realize before COPY or BUFFER_VIEW
|
||||
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view()),)), realize),
|
||||
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
|
||||
])
|
||||
|
||||
# ** this breaks down realized ops into STOREs and rewrites the ops to LOADs
|
||||
|
||||
def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
if isinstance((val:=to_store.arg), UOp): ctx.var_vals.update([val.unbind()])
|
||||
return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape)
|
||||
|
||||
def append_kernel(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), append_op(ctx, b, to_store))
|
||||
return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
|
||||
|
||||
@@ -370,11 +380,11 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
|
||||
|
||||
break_sched = PatternMatcher([
|
||||
# consts are always fused and generated
|
||||
(UPatSrc({Ops.CONST, Ops.BIND}), generate_valid),
|
||||
(UPatScheduled({Ops.CONST, Ops.BIND}), generate_valid),
|
||||
# everything else is a VIEW of BUFFER that either realizes or fuses
|
||||
(UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
|
||||
(UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
|
||||
# just load realized buffers
|
||||
(UPat(Ops.BUFFER, name="b").view(name="v"), lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, v.dtype, (b, v.st.to_uop()))),
|
||||
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, base.dtype, (b, base.st.to_uop()))),
|
||||
])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
|
||||
Reference in New Issue
Block a user