diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7756087b7b..1edb3dec95 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -70,7 +70,7 @@ jobs: source venv/bin/activate pip install $GITHUB_WORKSPACE cp $GITHUB_WORKSPACE/examples/beautiful_mnist.py . - BS=2 STEPS=10 python beautiful_mnist.py + BS=2 STEPS=10 MAX_BUFFER_SIZE=0 python beautiful_mnist.py - name: Test Docs Build run: python -m mkdocs build --strict - name: Test Docs @@ -141,7 +141,7 @@ jobs: sudo apt update || true sudo apt install -y --no-install-recommends ninja-build - name: Test beautiful_mnist in torch with TINY_BACKEND - run: STEPS=20 CPU=1 TARGET_EVAL_ACC_PCT=90.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py + run: STEPS=20 CPU=1 TARGET_EVAL_ACC_PCT=90.0 MAX_BUFFER_SIZE=0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py - name: Test some torch tests (expect failure) run: python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true diff --git a/extra/export_model.py b/extra/export_model.py index 1d19976da1..e3b4a9a7f5 100644 --- a/extra/export_model.py +++ b/extra/export_model.py @@ -13,12 +13,20 @@ from collections import OrderedDict EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "CL"] def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]: + # memory-planned subbuffers can have multiple Buffer objects for the same memory region + canon, _seen = {}, {} + for ji in run.jit_cache: + for b in ji.bufs: + if b is not None: canon[id(b)] = _seen.setdefault((id(b.base._buf), b.offset, b.size, b.dtype), b) + special_names = {id(canon[k]): v for k, v in special_names.items() if k in canon} + functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0 for ji in run.jit_cache: fxn: ProgramSpec = ji.prg.p functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same cargs = [] for i,arg in enumerate(ji.bufs): + arg = canon[id(arg)] key = id(arg) if key not in bufs: if key in special_names: diff --git a/test/null/test_memory_planner.py b/test/null/test_memory_planner.py index b60a9f0f69..d83869c8aa 100644 --- a/test/null/test_memory_planner.py +++ b/test/null/test_memory_planner.py @@ -1,42 +1,73 @@ import unittest from tinygrad import dtypes -from tinygrad.device import Buffer -from tinygrad.engine.memory import _internal_memory_planner +from tinygrad.uop.ops import UOp, Ops +from tinygrad.engine.memory import memory_plan_rewrite global_map = {} +held_bufs: set[UOp] = set() def b(i, base=None, offset=0, pin=False, size=16): global global_map if i in global_map: return global_map[i] - global_map[i] = Buffer("NULL", size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset) - if pin: global_map[i].ref(1) + if base is not None: + global_map[i] = global_map[base] + return global_map[i] + global_map[i] = UOp.new_buffer("NULL", size, dtypes.int8) + if pin: held_bufs.add(global_map[i]) return global_map[i] -def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]], copies:list[tuple[Buffer, Buffer]]|None=None): - assigned = _internal_memory_planner(buffers, copies=copies) +def _make_linear(buffer_lists, copies=None): + copy_pairs = {frozenset((id(dst), id(src))) for dst, src in copies} if copies else set() + calls = [] + for bufs in buffer_lists: + is_copy = len(bufs) == 2 and frozenset((id(bufs[0]), id(bufs[1]))) in copy_pairs + calls.append(UOp(Ops.CALL, dtypes.void, (UOp(Ops.COPY if is_copy else Ops.SINK), *bufs))) + return UOp(Ops.LINEAR, src=tuple(calls)) - taken_parts = set() +def _get_arena(buf, linear, result): + for orig_si, new_si in zip(linear.src, result.src): + for orig, new in zip(orig_si.src[1:], new_si.src[1:]): + if orig is buf and new.op is Ops.BUFFER_VIEW: return new.src[0] + return None + +def check_assign(buffer_lists, copies=None): + linear = _make_linear(buffer_lists, copies) + result = memory_plan_rewrite(linear, held_bufs) + + # build mapping: original buf -> (arena, offset_bytes, nbytes) from the result + replace_map: dict[int, tuple[UOp, int, int]] = {} + for orig_si, new_si in zip(linear.src, result.src): + for orig, new in zip(orig_si.src[1:], new_si.src[1:]): + if new.op is Ops.BUFFER_VIEW and id(orig) not in replace_map: + replace_map[id(orig)] = (new.src[0], new.arg[1] * new.dtype.itemsize, new.arg[0] * new.dtype.itemsize) + + # verify pinned buffers are not planned + for buf in held_bufs: + assert id(buf) not in replace_map, "pinned buffer was planned" + + # compute lifetimes first_appearance, last_appearance = {}, {} - for i,u in enumerate(buffers): - for buf in u: - if buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0: continue - if buf.base not in first_appearance: first_appearance[buf.base] = i - last_appearance[buf.base] = i + for i, bufs in enumerate(buffer_lists): + for buf in bufs: + if buf in held_bufs: continue + if id(buf) not in first_appearance: first_appearance[id(buf)] = i + last_appearance[id(buf)] = i - for i,u in enumerate(buffers): - for buf in u: - if buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0: continue - cur, base = assigned.get(buf, buf), assigned.get(buf.base, buf.base) - if buf._base is not None: - assert cur.base == base.base and cur.offset == buf.offset + base.offset, f"failed: {buf} {cur} {base} {buf.offset} {base.offset}" - else: - for part in taken_parts: - assert buf.base == part[3] or part[0] != cur.base or part[1] + part[2] <= cur.offset or part[1] >= cur.offset + buf.nbytes - if first_appearance[buf.base] == i: taken_parts.add((cur.base, cur.offset, buf.nbytes, buf.base)) - if last_appearance[buf.base] == i: taken_parts.remove((cur.base, cur.offset, buf.nbytes, buf.base)) + # verify non-overlapping: no two live buffers share the same arena region + taken_parts: set[tuple[int, int, int, int]] = set() # (id(arena), offset, nbytes, id(buf)) + for i, bufs in enumerate(buffer_lists): + for buf in bufs: + if buf in held_bufs or id(buf) not in replace_map: continue + arena, off, nb = replace_map[id(buf)] + for part in taken_parts: + assert id(buf) == part[3] or part[0] != id(arena) or part[1] + part[2] <= off or part[1] >= off + nb, \ + f"overlap at step {i}: [{off}, {off+nb}) conflicts with [{part[1]}, {part[1]+part[2]})" + if first_appearance.get(id(buf)) == i: taken_parts.add((id(arena), off, nb, id(buf))) + if last_appearance.get(id(buf)) == i: taken_parts.discard((id(arena), off, nb, id(buf))) class TestMemoryPlanner(unittest.TestCase): def setUp(self): global global_map + held_bufs.clear() global_map = {} def test_simple_buffer(self): @@ -140,9 +171,11 @@ class TestMemoryPlanner(unittest.TestCase): [b(1), b(2)], [b(3), b(2)], ] - assigned = _internal_memory_planner(bs, copies=[(b(1), b(0))]) - r1, r2 = assigned.get(b(1), b(1)), assigned.get(b(2), b(2)) - assert r1.base != r2.base + linear = _make_linear(bs, copies=[(b(1), b(0))]) + result = memory_plan_rewrite(linear) + r1_arena, r2_arena = _get_arena(b(1), linear, result), _get_arena(b(2), linear, result) + assert r1_arena is not None and r2_arena is not None + assert r1_arena is not r2_arena def test_copy_bufs_reuse_among_copies(self): bs = [ @@ -150,9 +183,11 @@ class TestMemoryPlanner(unittest.TestCase): [b(2), b(1)], [b(3), b(2)], ] - assigned = _internal_memory_planner(bs, copies=[(b(1), b(0)), (b(2), b(1))]) - r1, r2 = assigned.get(b(1), b(1)), assigned.get(b(2), b(2)) - assert r1.base == r2.base + linear = _make_linear(bs, copies=[(b(1), b(0)), (b(2), b(1))]) + result = memory_plan_rewrite(linear) + r1_arena, r2_arena = _get_arena(b(1), linear, result), _get_arena(b(2), linear, result) + assert r1_arena is not None and r2_arena is not None + assert r1_arena is r2_arena def test_compute_bufs_reuse_among_compute(self): bs = [ @@ -161,9 +196,11 @@ class TestMemoryPlanner(unittest.TestCase): [b(3), b(2)], [b(4), b(3)], ] - assigned = _internal_memory_planner(bs, copies=[(b(1), b(0))]) - r2, r3 = assigned.get(b(2), b(2)), assigned.get(b(3), b(3)) - assert r2.base == r3.base + linear = _make_linear(bs, copies=[(b(1), b(0))]) + result = memory_plan_rewrite(linear) + r2_arena, r3_arena = _get_arena(b(2), linear, result), _get_arena(b(3), linear, result) + assert r2_arena is not None and r3_arena is not None + assert r2_arena is r3_arena def test_copy_and_compute_no_cross_reuse(self): bs = [ @@ -171,9 +208,11 @@ class TestMemoryPlanner(unittest.TestCase): [b(2), b(1)], [b(3), b(2)], ] - assigned = _internal_memory_planner(bs, copies=[(b(2), b(1))]) - r0, r2 = assigned.get(b(0), b(0)), assigned.get(b(2), b(2)) - assert r0.base != r2.base + linear = _make_linear(bs, copies=[(b(2), b(1))]) + result = memory_plan_rewrite(linear) + r0_arena, r2_arena = _get_arena(b(0), linear, result), _get_arena(b(2), linear, result) + assert r0_arena is not None and r2_arena is not None + assert r0_arena is not r2_arena def test_multiple_copy_bufs_with_offsets(self): bs = [ diff --git a/test/null/test_real_world.py b/test/null/test_real_world.py index 9cfbfdb1de..61b971d883 100644 --- a/test/null/test_real_world.py +++ b/test/null/test_real_world.py @@ -74,7 +74,10 @@ class TestRealWorld(unittest.TestCase): def test(t, t2): for l in model: t = l(t, t2) return t.realize() - helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.0002, 37) + + # TODO: support _offset on CL to get mem down to 0.0002 + exp_mem = 0.00037 if Device.DEFAULT == "CL" else 0.0002 + helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, exp_mem, 37) @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16") def test_llama(self): diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index ad1ece845b..0712e82e5f 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -6,11 +6,11 @@ from tinygrad.device import Buffer, Compiled, Device, MultiBuffer from tinygrad.dtype import DType from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops, buffers from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates -from tinygrad.engine.memory import _internal_memory_planner, _collect_bufs +from tinygrad.engine.memory import memory_plan_rewrite, _collect_bufs from tinygrad.engine.schedule import linear_to_schedule from tinygrad.nn.state import get_parameters from tinygrad.schedule.rangeify import mop_cleanup -from dataclasses import dataclass, replace +from dataclasses import dataclass def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]: kept, onetime = [], [] @@ -212,11 +212,13 @@ class CapturedJit(Generic[ReturnType]): def free_intermediates(self): depends: set[Buffer|None] = set([None]) update_depends(depends, self.jit_cache) - for b in depends: - if b is not None: - if b.is_allocated(): b.deallocate() - if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate() - self.__post_init__() # reset the graph state + arenas = {b._base for b in depends if b is not None and b._base is not None} + to_free = {b for b in depends if b is not None} | {b for ei in self.jit_cache for b in ei.bufs if b is not None and b._base in arenas} + for b in to_free: + if hasattr(b, '_buf'): b.deallocate() + for a in arenas: + if a.allocated_views == 0 and a.is_allocated(): a.deallocate() + self.__post_init__() # jit exec def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType: @@ -322,7 +324,7 @@ class TinyJit(Generic[ReturnType]): _check_no_non_tensor_return(ret) if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs") - # combine all captured linears into one and convert to ExecItems + # combine all captured linears into one, memory plan, and convert to ExecItems big_linear = UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears]))) del self._linears @@ -334,7 +336,9 @@ class TinyJit(Generic[ReturnType]): ei.run(var_vals, jit=True) del onetime_linear - with Context(BEAM=getenv("JITBEAM", BEAM.value)): jit_cache = [ei.lower() for ei in linear_to_schedule(big_linear)] + held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER} + with Context(BEAM=getenv("JITBEAM", BEAM.value)): + jit_cache = [ei.lower() for ei in linear_to_schedule(memory_plan_rewrite(big_linear, held_bufs))] del big_linear # track inputs that are views of buffers @@ -346,11 +350,6 @@ class TinyJit(Generic[ReturnType]): input_buffers.append(b) extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype)) - # memory planning (optional) - copies = [(cast(Buffer,ji.bufs[0]),cast(Buffer,ji.bufs[1])) for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec))] - assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], copies, debug_prefix="JIT ") - jit_cache = [replace(item, bufs=[assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache] - input_replace = get_input_replace(jit_cache, input_buffers) if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found") diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index 25748a5a0a..783770e044 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -1,8 +1,6 @@ -from typing import cast from collections import defaultdict -from tinygrad.engine.realize import ExecItem -from tinygrad.device import Device, Buffer -from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up +from tinygrad.device import Device +from tinygrad.helpers import NO_MEMORY_PLANNER, DEBUG, round_up from tinygrad.uop.ops import UOp, Ops from tinygrad.dtype import dtypes from tinygrad.runtime.support.memory import TLSFAllocator @@ -12,71 +10,56 @@ def _collect_bufs(u:UOp) -> list[UOp]: if u.op in {Ops.MSELECT, Ops.MSTACK}: return [b for s in u.src for b in _collect_bufs(s)] return [] +def _can_plan(b:UOp, held_bufs:set[UOp]) -> bool: + if b in held_bufs: return False + devs = (b.device,) if isinstance(b.device, str) else b.device + return all(not d.startswith(("DISK", "TINYFS")) and hasattr(Device[d].allocator, "_offset") for d in devs) + LaneKey = tuple[str, int] -# **************** memory planning **************** +def memory_plan_rewrite(linear:UOp, held_bufs:set[UOp]|None=None) -> UOp: + if NO_MEMORY_PLANNER: return linear + if held_bufs is None: held_bufs = set() -def _internal_memory_planner(buffers:list[list[Buffer]], copies:list[tuple[Buffer, Buffer]]|None=None, - ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]: - if NO_MEMORY_PLANNER: return {} - first_appearance, last_appearance, buf_to_opt = {}, {}, set() - for i,u in enumerate(buffers): - for buf in u: - if not ignore_checks and (buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0): continue - if buf.base not in first_appearance: first_appearance[buf.base] = i - last_appearance[buf.base] = i - buf_to_opt.add(buf) + # compute lifetimes for all plannable internal buffers + first_appearance:dict[UOp, int] = {} + last_appearance:dict[UOp, int] = {} + copy_bufs: set[UOp] = set() + for i, si in enumerate(linear.src): + si_bufs = [b for src in si.src[1:] for b in _collect_bufs(src) if _can_plan(b, held_bufs)] + for b in si_bufs: + if b not in first_appearance: first_appearance[b] = i + last_appearance[b] = i + if si.src[0].op is Ops.COPY: copy_bufs.update(si_bufs) + if not first_appearance: return linear - # Separate copy and compute buffers into different lanes and defer cross-queue frees to avoid introducing dependencies (copy->compute->copy) - copy_dsts, copy_srcs = ({dst.base for dst,_ in copies}, {src.base for _,src in copies}) if copies else (set(), set()) - def _key(buf) -> LaneKey: return (buf.device, 1 if buf in copy_dsts or buf in copy_srcs else 0) - buf_hold = {buf: last_appearance[buf] - first_appearance[buf] + 1 for buf in first_appearance if buf in copy_dsts or buf in copy_srcs} + # separate copy and compute buffers into different lanes to avoid introducing dependencies (copy->compute->copy) + def _key(b:UOp): return (b.device, 1 if b in copy_bufs else 0) + buf_hold = {b: last_appearance[b] - first_appearance[b] + 1 for b in first_appearance if b in copy_bufs} - # Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed. - buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \ - [((last_appearance[buf] + 1 + buf_hold.get(buf, 0), False), buf) for buf in first_appearance.keys()], key=lambda x: x[0]) - total_memory = sum(round_up(buf.nbytes, BLK:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%) + # suballocation: build sorted open/close events, then alloc/free in order + block_size = 256 + nbytes = {b: round_up(b.arg * b.dtype.itemsize, block_size) for b in first_appearance} + events = sorted([(first_appearance[b], True, b) for b in first_appearance] + + [(last_appearance[b] + 1 + buf_hold.get(b, 0), False, b) for b in first_appearance], key=lambda x: (x[0], x[1])) + total_memory = sum(nbytes.values()) * 2 - # Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator. - # Also track buffer replacements for buffers that do not support suballocation. - buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {} - reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list) - global_planner:dict[LaneKey, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=BLK, lv2_cnt=32))) - for (_, is_open_ev), buf in buffer_requests: - # Check if suballocation is possible for the given buffer and device. - if hasattr(Device[buf.device].allocator, "_offset"): - if is_open_ev: buffer_replace[buf] = (None, global_planner[_key(buf)][1].alloc(round_up(buf.nbytes, BLK))) - else: global_planner[_key(buf)][1].free(cast(int, buffer_replace[buf][1])) - global_planner[_key(buf)] = (max(global_planner[_key(buf)][0], buffer_replace[buf][1] + buf.nbytes), global_planner[_key(buf)][1]) - else: - key = (_key(buf), buf.dtype, buf.options, buf.nbytes) - if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None) - else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0])) + offsets:dict[UOp, int] = {} + peaks:dict[LaneKey, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=block_size, lv2_cnt=32))) + for _, is_open, buf in events: + if is_open: offsets[buf] = peaks[_key(buf)][1].alloc(nbytes[buf]) + else: peaks[_key(buf)][1].free(offsets[buf]) + peaks[_key(buf)] = (max(peaks[_key(buf)][0], offsets[buf] + buf.arg * buf.dtype.itemsize), peaks[_key(buf)][1]) + arena_sizes = {key: round_up(peak, block_size) for key, (peak, _) in peaks.items()} - # Allocate global buffers based on the memory planner. - global_buffers = {key: Buffer(key[0], round_up(sz, BLK), dtypes.int8) for key, (sz, _) in global_planner.items()} - buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[_key(buf)], off) for buf,(base,off) in buffer_replace.items()} + # build replace_map: each buffer becomes a BUFFER_VIEW into a shared per-device-lane arena + arenas = {key: UOp.new_buffer(key[0], sz, dtypes.int8) for key, sz in arena_sizes.items()} + replace_map:dict[UOp, UOp] = {} + for buf_uop, offset in offsets.items(): + assert offset % buf_uop.dtype.itemsize == 0, f"offset {offset} not aligned to {buf_uop.dtype.itemsize}" + replace_map[buf_uop] = UOp(Ops.BUFFER_VIEW, buf_uop.dtype, (arenas[_key(buf_uop)],), (buf_uop.arg, offset // buf_uop.dtype.itemsize)) - # Assign buffers. First, assign full buffers (not sub-buffers). - assigned:dict[Buffer, Buffer] = {} - for buf, (base, off) in buffer_resolve.items(): - if buf != base: - assigned[buf] = base if off is None else Buffer(buf.device, buf.size, buf.dtype, base=base, offset=off) + if DEBUG >= 1 and (omem:=sum(nbytes.values()) / 1e6) != (nmem:=sum(arena_sizes.values()) / 1e6): + print(f"memory reduced from {omem:.2f} MB -> {nmem:.2f} MB, {len(first_appearance)} -> {len(arenas)} bufs") - # Now assign sub-buffers. - for buf in buf_to_opt: - if buf._base is not None: - assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=(pbuf:=assigned.get(buf.base, buf.base)).base, offset=pbuf.offset+buf.offset) - - if DEBUG >= 1: - ak, av = dedup(x for x in assigned.keys() if x._base is None),dedup(x for x in assigned.values() if x._base is None)+list(global_buffers.values()) - omem, nmem = sum([x.nbytes for x in ak])/1e6, sum([x.nbytes for x in av])/1e6 - if omem != nmem: print(f"{debug_prefix}memory reduced from {omem:.2f} MB -> {nmem:.2f} MB,", f"{len(ak)} -> {len(av)} bufs") - - return assigned - -def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]: - # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. - assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule], - copies=[(cast(Buffer,si.bufs[0]),cast(Buffer,si.bufs[1])) for si in schedule if si.ast.op is Ops.COPY]) - return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule] + return linear.substitute(replace_map, name="memory plan", walk=True) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 8f19332b89..54043ed334 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -83,7 +83,7 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]: schedule.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata)) return schedule -from tinygrad.engine.memory import memory_planner +from tinygrad.engine.memory import memory_plan_rewrite from tinygrad.engine.realize import capturing from tinygrad.schedule.rangeify import get_kernel_graph from tinygrad.helpers import CAPTURING @@ -163,7 +163,9 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], di capturing[0].add_linear(linear, var_vals) return [], var_vals + held_bufs = ({b for b in linear_call.src[1:] if b.op is Ops.BUFFER} if linear_call.op is Ops.CALL else set()) + linear = memory_plan_rewrite(linear, held_bufs) + # convert LINEAR to ExecItems schedule: list[ExecItem] = linear_to_schedule(linear) - with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule) return schedule, var_vals