diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index 59965122cb..a1c0106a57 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -48,4 +48,4 @@ def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]: # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. assigned = _internal_memory_planner([si.bufs for si in schedule], noopt_buffers={b for si in schedule if si.ast.op is not UOps.SINK for b in si.bufs}) - return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule] + return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule] diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2e2b49694e..82a1a5095f 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -25,6 +25,7 @@ class ScheduleItem: ast: UOp bufs: Tuple[Buffer, ...] metadata: Tuple[Metadata, ...] + assign_preloads: Tuple[UOp, ...] @property def outputs(self) -> Tuple[Buffer, ...]: """Read/write or write only buffers in the schedule.""" @@ -36,17 +37,6 @@ class ScheduleItem: @functools.cached_property def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,) -@dataclass(frozen=True) -class LBScheduleItem: - ast: UOp - bufs: Tuple[LazyBuffer, ...] - ubufs: Tuple[UOp, ...] - metadata: Tuple[Metadata, ...] - @property - def outputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[:len(self.ast.src)] if self.ast.op is UOps.SINK else self.bufs[0:1] - @property - def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:] - # *** UOp with VIEW (movementops) rewriting to UOp we can index *** # ** helpers for doing movementops on uops @@ -132,8 +122,10 @@ view_right = merge_views+PatternMatcher([ @dataclass(frozen=True) class ScheduleItemContext: var_vals: Dict[Variable, int] + assigned: Set[UOp] sts: Set[ShapeTracker] = field(default_factory=set) bufs: List[UOp] = field(default_factory=list) + assign_preloads: List[UOp] = field(default_factory=list) def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]: if (st:=unwrap(x.st)) in ctx.sts: return None @@ -145,11 +137,15 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]: def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp: ctx.bufs.append(x) return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1) +append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf)]) + +def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp: + if b in ctx.assigned: ctx.assign_preloads.append(b) + return x.replace(op=UOps.LOAD) to_si = PatternMatcher([ - (UPat(UOps.BUFFER, name="x"), _append_buf), (UPat(UOps.VIEW, name="x"), _append_st_vars), - (UPat(UOps.PRELOAD, name="x"), lambda _,x: x.replace(op=UOps.LOAD)), + (UPat(UOps.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload), (UPat(UOps.CONTIGUOUS, src=(UPat.var("x"),)), lambda _,x: x), (UPat(UOps.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda _,x: x), ]) @@ -157,8 +153,8 @@ to_si = PatternMatcher([ PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = [] def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp: sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right) - ret = graph_rewrite(sink, to_si, ctx) - PROCESS_REPLAY_CAPTURE.append((base_sink, ScheduleItemContext(ctx.var_vals), ret)) + ret = graph_rewrite(graph_rewrite(sink, to_si, ctx), append_bufs, ctx) + PROCESS_REPLAY_CAPTURE.append((base_sink, ScheduleItemContext(ctx.var_vals, ctx.assigned), ret)) return ret if getenv("RUN_PROCESS_REPLAY"): @@ -168,18 +164,16 @@ if getenv("RUN_PROCESS_REPLAY"): # *** List[LazyBuffer] lowering to ScheduleItem *** -def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], metadata:Dict[UOp, Metadata], - cache:Dict[LazyBuffer, UOp]) -> UOp: +def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], metadata:Dict[UOp, Metadata], cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r if buf is not buf.base: - cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, metadata, cache).view(buf.st) + cache[buf] = ret = to_uop(buf.base, outputs, buf_uops, metadata, cache).view(buf.st) return ret if buf.op is MetaOps.CONST: return buf_uops[buf.buffer] dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs: - if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf) return UOp(UOps.PRELOAD if buf.is_realized() else UOps.LOAD, dtype, (ubuf, buf.st.to_uop())) - src = tuple(to_uop(x, outputs, inputs, buf_uops, metadata, cache) for x in buf.srcs) + src = tuple(to_uop(x, outputs, buf_uops, metadata, cache) for x in buf.srcs) if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg) elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src) elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]), buf.arg) @@ -191,26 +185,25 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu if buf.metadata is not None: metadata[ret] = buf.metadata return ret -def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], uop_bufs:Dict[UOp, Buffer], - var_vals:Dict[Variable, int]) -> LBScheduleItem: +def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], uop_bufs:Dict[UOp, Buffer], assigned:Set[UOp], + var_vals:Dict[Variable, int]) -> ScheduleItem: """describe the computation for a LazyBuffer with UOp + inputs + var_vals""" cache: Dict[LazyBuffer, UOp] = {} - inputs: List[LazyBuffer] = [] metadata: Dict[UOp, Metadata] = {} sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(), - to_uop(out, outs, inputs, buf_uops, metadata, cache)) for out in outs)) + to_uop(out, outs, buf_uops, metadata, cache)) for out in outs)) # assert cyclic dependency - for b,reads in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD, UOps.LOAD}), key=lambda x:x.src[0]): + for b,reads in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD, UOps.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]): if not all_same([x.op for x in reads]): raise RuntimeError(f"cycle detected in kernel.\nhelp: consider using .contiguous() to load the pre-assign version of {uop_bufs[b]}.") - sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals)) + sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals, assigned)) # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0: if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \ and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets): raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) - return LBScheduleItem(sink, tuple(outs+inputs), tuple(ctx.bufs), tuple(dedup(metadata.values()))) + return ScheduleItem(sink, tuple(b for u in ctx.bufs if (b:=uop_bufs[u]).size != 0), tuple(dedup(metadata.values())), tuple(ctx.assign_preloads)) # *** DAG creation: decide which LazyBuffers should realize *** @@ -372,7 +365,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] buf_uops: Dict[Buffer, UOp] = {} uop_bufs: Dict[UOp, Buffer] = {} var_vals: Dict[Variable, int] = {} - assign_targets: Dict[LazyBuffer, LazyBuffer] = {} + assigned: Set[UOp] = set() lazybufs_to_realize: Dict[Buffer, LazyBuffer] = {} for buf in realizes: if buf.realized is None and buf.op is not MetaOps.CONST: @@ -399,24 +392,24 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] buf_uops[buf.buffer] = uop uop_bufs[uop] = buf.buffer if buf.realized is None: - if buf.op is MetaOps.ASSIGN: assign_targets[buf.srcs[0]] = buf + if buf.op is MetaOps.ASSIGN: assigned.add(buf_uops[buf.buffer]) if buf.op is not MetaOps.CONST:output_groups[reduce_for_op.get(buf, buf)].append(buf_uops[buf.buffer]) # preschedule all buffers in realizes - prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, uop_bufs, + prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, uop_bufs, assigned, var_vals) for outs in output_groups.values()] schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs} - graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list) - in_degree: DefaultDict[LBScheduleItem, int] = defaultdict(int) + graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list) + in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int) for lsi in prescheduled: # realize outputs before a parent is assigned to - parents_assigns = dedup(schedule_targets[assign_targets[x]] for x in lsi.inputs if x in assign_targets) + parents_assigns = dedup(xsi for x in lsi.assign_preloads if (xsi:=schedule_targets.get(uop_bufs[x])) and xsi is not lsi) for assign in parents_assigns: graph[lsi].append(assign) in_degree[assign] += 1 # realize outputs after all parents are realized - scheduled_parents = dedup(xsi for x in lsi.inputs if (xsi:=schedule_targets.get(x)) is not None) + scheduled_parents = dedup(xsi for x in lsi.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns) for x in scheduled_parents: graph[x].append(lsi) in_degree[lsi] += 1 @@ -424,13 +417,12 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] queue = deque(lsi for lsi in prescheduled if in_degree[lsi] == 0) schedule: List[ScheduleItem] = [] while queue: - lsi = queue.popleft() - schedule.append(si:=ScheduleItem(lsi.ast, tuple(b for u in lsi.ubufs if (b:=uop_bufs[u]).size != 0), lsi.metadata)) + schedule.append(si:=queue.popleft()) for b in si.outputs: del lazybufs_to_realize[b].srcs # can only schedule once if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m: if DEBUG >= 3: print(si) raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.") - for x in graph[lsi]: + for x in graph[si]: in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x)