big graph first [pr] (#7443)

* big graph first [pr]

* move things
This commit is contained in:
qazal
2024-10-31 14:10:11 +02:00
committed by GitHub
parent 38b1790575
commit c5a50465d1

View File

@@ -40,9 +40,9 @@ class ScheduleItem:
@dataclass(frozen=True)
class ScheduleContext:
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict)
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict)
var_vals: Dict[Variable, int] = field(default_factory=dict)
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
@@ -237,10 +237,12 @@ break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b,
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) == 0: return [], {}
store_groups, lazybufs_to_realize, assigns = get_realizes(outs)
# create the big graph
ctx = ScheduleContext()
cache: Dict[LazyBuffer, UOp] = {}
big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs))
# get realizes
store_groups, lazybufs_to_realize, assigns = get_realizes(outs)
# split realizes into small graphs
graph_rewrite(big_graph, break_sched, realizes:={(u:=ctx.buf_uops[b]):u for b in lazybufs_to_realize})
assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}