diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index dc8c6ba713..a963fe009f 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -43,17 +43,17 @@ def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2 @dataclass(frozen=True) class ScheduleContext: - ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata + lazybufs: Dict[UOp, LazyBuffer] = field(default_factory=dict) # this maps BUFFER uops to Metadata var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule realizes: Dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) -def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazybufs:Dict[Buffer, LazyBuffer], cache:Dict[LazyBuffer, UOp]) -> UOp: +def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r if buf is not buf.base: - cache[buf] = ret = to_uop(buf.base, ctx, buffers, lazybufs, cache).view(buf.st) + cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st) return ret # make things that can't be images not images if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or @@ -68,18 +68,17 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazyb buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer op = None elif buf.op is Ops.ASSIGN: - target, new_val = [to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs] + target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs] ctx.assigns.add(ubuf:=target.buf_uop) op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg) else: buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer - op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs), + op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg) cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st) if op is not None: - lazybufs[buf.buffer] = buf + ctx.lazybufs[ubuf] = buf ctx.allbufs[ubuf] = ret - if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata for x in op.src: if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None return ret @@ -160,9 +159,9 @@ view_right = merge_views+PatternMatcher([ @dataclass(frozen=True) class ScheduleItemContext: - var_vals: Dict[Variable, int] + lazybufs: Dict[UOp, LazyBuffer] assigned: Set[UOp] - ubuf_metadata: Dict[UOp, Metadata] + var_vals: Dict[Variable, int] sinked: Dict[UOp, UOp] sts: Set[ShapeTracker] = field(default_factory=set) bufs: List[UOp] = field(default_factory=list) @@ -194,7 +193,7 @@ to_si = PatternMatcher([ # ** fusion def fuse_src(ctx:ScheduleItemContext, b:UOp, to_store:UOp, base:UOp) -> UOp: - if (metadata:=ctx.ubuf_metadata.get(b)) is not None: ctx.metadata.add(metadata) + if (metadata:=ctx.lazybufs[b].metadata) is not None: ctx.metadata.add(metadata) return to_store lazy = PatternMatcher([ @@ -206,8 +205,8 @@ lazy = PatternMatcher([ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),]) def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]: - si_ctx = ScheduleItemContext(ctx.var_vals, ctx.assigns, ctx.ubuf_metadata, {x.buf_uop:x.src[2] for x in pre.src}, - metadata={mx for x in pre.src if (mx:=ctx.ubuf_metadata.get(x.buf_uop))}) + si_ctx = ScheduleItemContext(ctx.lazybufs, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src}, + metadata={mx for x in pre.src if (mx:=ctx.lazybufs[x.buf_uop].metadata) is not None}) # fuse and fold store -> loads sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, si_ctx) # assert cyclic dependency @@ -386,8 +385,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] ctx = ScheduleContext() cache: Dict[LazyBuffer, UOp] = {} buffers: Dict[UOp, Buffer] = {} - lazybufs: Dict[Buffer, LazyBuffer] = {} - big_graph = UOp.sink(*(to_uop(x, ctx, buffers, lazybufs, cache) for x in outs)) + big_graph = UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)) # get realizes graph_rewrite(big_graph, do_realize, ctx.realizes) store_groups = group_realizes(ctx, ctx.realizes) @@ -399,6 +397,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] ast, ast_ctx = full_ast_rewrite(UOp.sink(*(ctx.realizes[u] for u in store_uops)), ctx) prescheduled.append(ScheduleItem(ast, tuple(buffers[u] for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata), frozenset(x.buf_uop for x in ast_ctx.assign_preloads))) + for u in ast_ctx.sinked: del ast_ctx.lazybufs[u].srcs # can only schedule once # do BFS schedule_targets = {out:si for si in prescheduled for out in si.outputs} graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list) @@ -418,7 +417,6 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] schedule: List[ScheduleItem] = [] while queue: schedule.append(si:=queue.popleft()) - for b in si.outputs: del lazybufs[b].srcs # can only schedule once for x in graph[si]: in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x)