mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
ops metadata map try 2, early fuse [pr] (#7893)
* make this return early * delete that * ops metadata map try 2, early fuse [pr]
This commit is contained in:
@@ -39,11 +39,12 @@ class ScheduleItem:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleContext:
|
||||
lazybufs: Dict[UOp, LazyBuffer] = field(default_factory=dict) # this maps BUFFER uops to Metadata
|
||||
lazybufs: Dict[UOp, LazyBuffer] = field(default_factory=dict) # this maps BUFFER uops of this schedule to the underlying lazybuffer
|
||||
var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
|
||||
assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
|
||||
realizes: Dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
|
||||
allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
|
||||
ops_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata
|
||||
children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
class UPatSrc(UPat):
|
||||
@@ -144,6 +145,7 @@ view_right = merge_views+PatternMatcher([
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItemContext:
|
||||
lazybufs: Dict[UOp, LazyBuffer]
|
||||
ops_metadata: Dict[UOp, Metadata]
|
||||
assigns: Set[UOp]
|
||||
var_vals: Dict[Variable, int]
|
||||
sinked: Dict[UOp, UOp]
|
||||
@@ -177,12 +179,8 @@ to_si = PatternMatcher([
|
||||
|
||||
# ** fusion
|
||||
|
||||
def fuse_src(ctx:ScheduleItemContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
if (lbuf:=ctx.lazybufs.get(b)) is not None and (metadata:=lbuf.metadata) is not None: ctx.metadata.add(metadata)
|
||||
return to_store
|
||||
|
||||
lazy = PatternMatcher([
|
||||
(UPatSrc(), fuse_src),
|
||||
(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.metadata.add(m) if (m:=ctx.ops_metadata.get(x)) is not None else None),
|
||||
(UPat(Ops.BUFFER, name="b").view(name="v"), lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, v.dtype, (b, v.st.to_uop()))),
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
|
||||
])
|
||||
@@ -193,7 +191,7 @@ append_load = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x"), lambd
|
||||
if b in ctx.assigns else None)])
|
||||
|
||||
def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]:
|
||||
si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src},
|
||||
si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src},
|
||||
metadata={l.metadata for x in pre.src if (l:=ctx.lazybufs.get(x.buf_uop)) is not None and l.metadata is not None})
|
||||
# fuse and fold store -> loads
|
||||
ops_folding = lazy if len(si_ctx.sinked) == 1 else lazy+multioutput
|
||||
@@ -360,14 +358,18 @@ def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape)
|
||||
|
||||
def append_kernel(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
|
||||
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), append_op(ctx, b, to_store))
|
||||
return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
|
||||
|
||||
def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
|
||||
if (m:=ctx.lazybufs[b].metadata) is not None: ctx.ops_metadata[to_store] = m
|
||||
return to_store
|
||||
|
||||
break_sched = PatternMatcher([
|
||||
# consts are always fused and generated
|
||||
(UPatSrc({Ops.CONST, Ops.BIND}), generate_valid),
|
||||
# everything else is a VIEW of BUFFER that either realizes or fuses
|
||||
(UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else None),
|
||||
(UPatSrc(), lambda ctx,b,to_store,base: append_kernel(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
|
||||
])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
|
||||
Reference in New Issue
Block a user