diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 0e9557fdf2..49aab9eefa 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -39,11 +39,12 @@ class ScheduleItem: @dataclass(frozen=True) class ScheduleContext: - lazybufs: Dict[UOp, LazyBuffer] = field(default_factory=dict) # this maps BUFFER uops to Metadata + lazybufs: Dict[UOp, LazyBuffer] = field(default_factory=dict) # this maps BUFFER uops of this schedule to the underlying lazybuffer var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule realizes: Dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op + ops_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) class UPatSrc(UPat): @@ -144,6 +145,7 @@ view_right = merge_views+PatternMatcher([ @dataclass(frozen=True) class ScheduleItemContext: lazybufs: Dict[UOp, LazyBuffer] + ops_metadata: Dict[UOp, Metadata] assigns: Set[UOp] var_vals: Dict[Variable, int] sinked: Dict[UOp, UOp] @@ -177,12 +179,8 @@ to_si = PatternMatcher([ # ** fusion -def fuse_src(ctx:ScheduleItemContext, b:UOp, to_store:UOp, base:UOp) -> UOp: - if (lbuf:=ctx.lazybufs.get(b)) is not None and (metadata:=lbuf.metadata) is not None: ctx.metadata.add(metadata) - return to_store - lazy = PatternMatcher([ - (UPatSrc(), fuse_src), + (UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.metadata.add(m) if (m:=ctx.ops_metadata.get(x)) is not None else None), (UPat(Ops.BUFFER, name="b").view(name="v"), lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, v.dtype, (b, v.st.to_uop()))), (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x), ]) @@ -193,7 +191,7 @@ append_load = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x"), lambd if b in ctx.assigns else None)]) def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]: - si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src}, + si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src}, metadata={l.metadata for x in pre.src if (l:=ctx.lazybufs.get(x.buf_uop)) is not None and l.metadata is not None}) # fuse and fold store -> loads ops_folding = lazy if len(si_ctx.sinked) == 1 else lazy+multioutput @@ -360,14 +358,18 @@ def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp: return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape) def append_kernel(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp: - ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store) + ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), append_op(ctx, b, to_store)) return UOp(Ops.LOAD, base.dtype, (b, st.to_uop())) +def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp: + if (m:=ctx.lazybufs[b].metadata) is not None: ctx.ops_metadata[to_store] = m + return to_store + break_sched = PatternMatcher([ # consts are always fused and generated (UPatSrc({Ops.CONST, Ops.BIND}), generate_valid), # everything else is a VIEW of BUFFER that either realizes or fuses - (UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else None), + (UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)), ]) @track_rewrites(named=True)