memplanner opt copy bufs (#15110)

* mtp

* x

* tests

* ss

* simp

* less slop

* x

* cleaner

* rm

* m

* c

* x

* f
This commit is contained in:
nimlgen
2026-03-08 22:28:01 +03:00
committed by GitHub
parent 633264feae
commit 6ac99fd4c9
4 changed files with 110 additions and 32 deletions

View File

@@ -11,8 +11,8 @@ def b(i, base=None, offset=0, pin=False, size=16):
if pin: global_map[i].ref(1)
return global_map[i]
def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]]):
assigned = _internal_memory_planner(buffers, noopt_buffers=None)
def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]], copies:list[tuple[Buffer, Buffer]]|None=None):
assigned = _internal_memory_planner(buffers, copies=copies)
taken_parts = set()
first_appearance, last_appearance = {}, {}
@@ -134,5 +134,75 @@ class TestMemoryPlanner(unittest.TestCase):
]
check_assign(bs)
def test_copy_bufs_separate_from_compute(self):
  """b(1) is a copy destination while b(2) is a plain compute output; the planner must not give them the same base."""
  schedule = [[b(0), b(1)], [b(1), b(2)], [b(3), b(2)]]
  assigned = _internal_memory_planner(schedule, copies=[(b(1), b(0))])
  copy_res = assigned.get(b(1), b(1))
  comp_res = assigned.get(b(2), b(2))
  assert copy_res.base != comp_res.base
def test_copy_bufs_reuse_among_copies(self):
  """Both b(1) and b(2) appear as copy destinations, so their allocations may share a base."""
  schedule = [[b(0), b(1)], [b(2), b(1)], [b(3), b(2)]]
  assigned = _internal_memory_planner(schedule, copies=[(b(1), b(0)), (b(2), b(1))])
  first = assigned.get(b(1), b(1))
  second = assigned.get(b(2), b(2))
  assert first.base == second.base
def test_compute_bufs_reuse_among_compute(self):
  """Pure compute buffers b(2) and b(3) have disjoint lifetimes and should share a base allocation."""
  schedule = [[b(0), b(1)], [b(2), b(1)], [b(3), b(2)], [b(4), b(3)]]
  assigned = _internal_memory_planner(schedule, copies=[(b(1), b(0))])
  earlier = assigned.get(b(2), b(2))
  later = assigned.get(b(3), b(3))
  assert earlier.base == later.base
def test_copy_and_compute_no_cross_reuse(self):
  """A compute buffer (b(0)) and a copy destination (b(2)) must never be folded onto the same base."""
  schedule = [[b(0), b(1)], [b(2), b(1)], [b(3), b(2)]]
  assigned = _internal_memory_planner(schedule, copies=[(b(2), b(1))])
  comp_res = assigned.get(b(0), b(0))
  copy_res = assigned.get(b(2), b(2))
  assert comp_res.base != copy_res.base
def test_multiple_copy_bufs_with_offsets(self):
  """Two copy destinations from the same pinned source, mixed with a sub-buffer view (b(3)); full validity checked by check_assign."""
  steps = [[b(0, pin=True), b(1), b(2)],
           [b(3, base=0, offset=1, size=8), b(1), b(2)],
           [b(4), b(3)],
           [b(5), b(4)]]
  check_assign(steps, copies=[(b(1), b(0)), (b(2), b(0))])
def test_copy_bufs_pinned_mixed(self):
  """Pinned buffer b(0) mixed with chained copy destinations (b(1) -> b(3)); validity checked by check_assign."""
  steps = [[b(0, pin=True), b(1), b(2)],
           [b(1), b(3), b(2)],
           [b(4), b(3)],
           [b(5), b(4), b(0)]]
  check_assign(steps, copies=[(b(1), b(0)), (b(3), b(1))])
def test_deferred_copy_frees_chain(self):
  """A chain of copy->compute pairs all sourced from pinned b(0), ending in a fresh pinned buffer.

  NOTE: the b(...) calls are kept in the exact original order/count, since b(..., pin=True)
  bumps a refcount on each call.
  """
  schedule, copy_pairs = [], []
  for step in range(6):
    copy_dst, comp_out = b(step * 2 + 1), b(step * 2 + 2)
    schedule += [[copy_dst, b(0, pin=True)], [comp_out, copy_dst]]
    copy_pairs.append((copy_dst, b(0, pin=True)))
  schedule.append([b(100, pin=True)])
  check_assign(schedule, copies=copy_pairs)
# Run the memory-planner test suite when this file is executed directly.
if __name__ == "__main__":
  unittest.main()

View File

@@ -108,7 +108,6 @@ class TestMultiRamUsage(unittest.TestCase):
def test_matmul_half(self): self._test_matmul_half(dev_count=2)
def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
@unittest.expectedFailure
def test_multi_layer_allreduce(self):
N = 32
devices_2 = ("NULL:1", "NULL:2")

View File

@@ -114,9 +114,9 @@ class GraphRunner(Runner):
assert ji.prg.p.local_size is not None
self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))
# used in MultiGraphRunner. the ints are id() of _bufs
self.w_dependency_map: dict[int, Any] = {}
self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)
# used in MultiGraphRunner. tracks (offset, end, dep) ranges per base buffer id to handle suballocated buffers correctly.
self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
assert jit_cache[0].prg is not None
super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())
@@ -132,19 +132,22 @@ class GraphRunner(Runner):
yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
def _access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any):
# To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
# whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
wait_nodes = []
for i,buf in enumerate(bufs):
if id(buf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(buf.base._buf)])
key, s, e = id(buf.base._buf), buf.offset, buf.offset + buf.nbytes
wait_nodes += [dep for st,en,dep in self.w_dependency_map[key] if st < e and s < en]
if i in write: wait_nodes += [dep for st,en,dep in self.r_dependency_map[key] if st < e and s < en]
for i,buf in enumerate(bufs):
key, s, e = id(buf.base._buf), buf.offset, buf.offset + buf.nbytes
if i in write:
if id(buf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(buf.base._buf)))
for i,buf in enumerate(bufs):
if i in write: self.w_dependency_map[id(buf.base._buf)] = new_dependency
else: self.r_dependency_map[id(buf.base._buf)].append(new_dependency)
for dmap in [self.w_dependency_map, self.r_dependency_map]:
kept = []
for st,en,dep in dmap[key]:
if st < min(s, en): kept.append((st, min(s, en), dep))
if max(e, st) < en: kept.append((max(e, st), en, dep))
dmap[key] = kept
self.w_dependency_map[key].append((s, e, new_dependency))
else: self.r_dependency_map[key].append((s, e, new_dependency))
return list({id(x):x for x in wait_nodes}.values())
@staticmethod
@@ -357,9 +360,8 @@ class TinyJit(Generic[ReturnType]):
jit_cache = pruned
# memory planning (optional)
# Exclude buffers involved in transfer ops to preserve parallelism.
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec)) for b in ji.bufs}
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
copies = [(cast(Buffer,ji.bufs[0]),cast(Buffer,ji.bufs[1])) for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec))]
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], copies, debug_prefix="JIT ")
jit_cache = [replace(item, bufs=[assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
input_replace = get_input_replace(jit_cache, input_buffers)

View File

@@ -7,43 +7,50 @@ from tinygrad.uop.ops import Ops
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.runtime.support.memory import TLSFAllocator
LaneKey = tuple[str, int]
# **************** memory planning ****************
def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
def _internal_memory_planner(buffers:list[list[Buffer]], copies:list[tuple[Buffer, Buffer]]|None=None,
ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
if NO_MEMORY_PLANNER: return {}
first_appearance, last_appearance, buf_to_opt = {}, {}, set()
for i,u in enumerate(buffers):
for buf in u:
should_skip = buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers)
if not ignore_checks and should_skip: continue
if not ignore_checks and (buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0): continue
if buf.base not in first_appearance: first_appearance[buf.base] = i
last_appearance[buf.base] = i
buf_to_opt.add(buf)
# Separate copy and compute buffers into different lanes and defer cross-queue frees to avoid introducing dependencies (copy->compute->copy)
copy_dsts, copy_srcs = ({dst.base for dst,_ in copies}, {src.base for _,src in copies}) if copies else (set(), set())
def _key(buf) -> LaneKey: return (buf.device, 1 if buf in copy_dsts or buf in copy_srcs else 0)
buf_hold = {buf: last_appearance[buf] - first_appearance[buf] + 1 for buf in first_appearance if buf in copy_dsts or buf in copy_srcs}
# Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
[((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
total_memory = sum(round_up(buf.nbytes, min_block_size:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%)
[((last_appearance[buf] + 1 + buf_hold.get(buf, 0), False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
total_memory = sum(round_up(buf.nbytes, BLK:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%)
# Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
# Also track buffer replacements for buffers that do not support suballocation.
buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=min_block_size, lv2_cnt=32)))
global_planner:dict[LaneKey, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=BLK, lv2_cnt=32)))
for (_, is_open_ev), buf in buffer_requests:
# Check if suballocation is possible for the given buffer and device.
if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType):
if is_open_ev: buffer_replace[buf] = (None, global_planner[buf.device][1].alloc(round_up(buf.nbytes, 0x1000)))
else: global_planner[buf.device][1].free(cast(int, buffer_replace[buf][1]))
global_planner[buf.device] = (max(global_planner[buf.device][0], buffer_replace[buf][1] + buf.nbytes), global_planner[buf.device][1])
if is_open_ev: buffer_replace[buf] = (None, global_planner[_key(buf)][1].alloc(round_up(buf.nbytes, BLK)))
else: global_planner[_key(buf)][1].free(cast(int, buffer_replace[buf][1]))
global_planner[_key(buf)] = (max(global_planner[_key(buf)][0], buffer_replace[buf][1] + buf.nbytes), global_planner[_key(buf)][1])
else:
key = (buf.device, buf.dtype, buf.options, buf.nbytes)
key = (_key(buf), buf.dtype, buf.options, buf.nbytes)
if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None)
else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0]))
# Allocate global buffers based on the memory planner.
global_buffers = {dev: Buffer(dev, round_up(sz, 0x1000), dtypes.int8) for dev, (sz, _) in global_planner.items()}
buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[buf.device], off) for buf,(base,off) in buffer_replace.items()}
global_buffers = {key: Buffer(key[0], round_up(sz, BLK), dtypes.int8) for key, (sz, _) in global_planner.items()}
buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[_key(buf)], off) for buf,(base,off) in buffer_replace.items()}
# Assign buffers. First, assign full buffers (not sub-buffers).
assigned:dict[Buffer, Buffer] = {}
@@ -66,5 +73,5 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None})
copies=[(cast(Buffer,si.bufs[0]),cast(Buffer,si.bufs[1])) for si in schedule if si.ast.op is Ops.COPY])
return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]