mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
@@ -40,9 +40,9 @@ class ScheduleItem:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleContext:
|
||||
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict)
|
||||
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict)
|
||||
var_vals: Dict[Variable, int] = field(default_factory=dict)
|
||||
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops
|
||||
ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
|
||||
var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
|
||||
|
||||
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
@@ -237,10 +237,12 @@ break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b,
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) == 0: return [], {}
|
||||
store_groups, lazybufs_to_realize, assigns = get_realizes(outs)
|
||||
# create the big graph
|
||||
ctx = ScheduleContext()
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs))
|
||||
# get realizes
|
||||
store_groups, lazybufs_to_realize, assigns = get_realizes(outs)
|
||||
# split realizes into small graphs
|
||||
graph_rewrite(big_graph, break_sched, realizes:={(u:=ctx.buf_uops[b]):u for b in lazybufs_to_realize})
|
||||
assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}
|
||||
|
||||
Reference in New Issue
Block a user