mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
memplanner opt copy bufs (#15110)
* mtp * x * tests * ss * simp * less slop * x * cleaner * rm * m * c * x * f
This commit is contained in:
@@ -11,8 +11,8 @@ def b(i, base=None, offset=0, pin=False, size=16):
|
||||
if pin: global_map[i].ref(1)
|
||||
return global_map[i]
|
||||
|
||||
def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]]):
|
||||
assigned = _internal_memory_planner(buffers, noopt_buffers=None)
|
||||
def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]], copies:list[tuple[Buffer, Buffer]]|None=None):
|
||||
assigned = _internal_memory_planner(buffers, copies=copies)
|
||||
|
||||
taken_parts = set()
|
||||
first_appearance, last_appearance = {}, {}
|
||||
@@ -134,5 +134,75 @@ class TestMemoryPlanner(unittest.TestCase):
|
||||
]
|
||||
check_assign(bs)
|
||||
|
||||
def test_copy_bufs_separate_from_compute(self):
  # b(1) is the destination of a copy, b(2) never appears in a copy pair:
  # the planner must not place them on the same base buffer.
  schedule = [
    [b(0), b(1)],
    [b(1), b(2)],
    [b(3), b(2)],
  ]
  plan = _internal_memory_planner(schedule, copies=[(b(1), b(0))])
  # buffers the planner left untouched are absent from the mapping, so fall back to the buffer itself
  copy_dst = plan.get(b(1), b(1))
  compute_buf = plan.get(b(2), b(2))
  assert copy_dst.base != compute_buf.base
|
||||
|
||||
def test_copy_bufs_reuse_among_copies(self):
  # b(1) and b(2) both take part in copies (as destination and/or source),
  # so they are expected to end up on one shared base buffer.
  schedule = [
    [b(0), b(1)],
    [b(2), b(1)],
    [b(3), b(2)],
  ]
  copy_pairs = [(b(1), b(0)), (b(2), b(1))]
  plan = _internal_memory_planner(schedule, copies=copy_pairs)
  first_copy = plan.get(b(1), b(1))
  second_copy = plan.get(b(2), b(2))
  assert first_copy.base == second_copy.base
|
||||
|
||||
def test_compute_bufs_reuse_among_compute(self):
  # Only b(1)/b(0) are involved in a copy; b(2) and b(3) are compute-only
  # and are expected to end up on one shared base buffer.
  schedule = [
    [b(0), b(1)],
    [b(2), b(1)],
    [b(3), b(2)],
    [b(4), b(3)],
  ]
  plan = _internal_memory_planner(schedule, copies=[(b(1), b(0))])
  compute_a = plan.get(b(2), b(2))
  compute_b = plan.get(b(3), b(3))
  assert compute_a.base == compute_b.base
|
||||
|
||||
def test_copy_and_compute_no_cross_reuse(self):
  # b(2) is a copy destination while b(0) is not part of any copy:
  # the two must resolve to different base buffers.
  schedule = [
    [b(0), b(1)],
    [b(2), b(1)],
    [b(3), b(2)],
  ]
  plan = _internal_memory_planner(schedule, copies=[(b(2), b(1))])
  compute_buf = plan.get(b(0), b(0))
  copy_dst = plan.get(b(2), b(2))
  assert compute_buf.base != copy_dst.base
|
||||
|
||||
def test_multiple_copy_bufs_with_offsets(self):
  # Two copies out of the pinned b(0), plus b(3) created with base=0/offset=1/size=8
  # (presumably a sub-buffer of b(0) -- confirm against the b() helper).
  # check_assign validates the resulting assignment.
  schedule = [
    [b(0, pin=True), b(1), b(2)],
    [b(3, base=0, offset=1, size=8), b(1), b(2)],
    [b(4), b(3)],
    [b(5), b(4)],
  ]
  copy_pairs = [(b(1), b(0)), (b(2), b(0))]
  check_assign(schedule, copies=copy_pairs)
|
||||
|
||||
def test_copy_bufs_pinned_mixed(self):
  # A pinned buffer b(0) that is a copy source and is used again in the last step,
  # mixed with a chained copy b(1) -> b(3). check_assign validates the resulting assignment.
  schedule = [
    [b(0, pin=True), b(1), b(2)],
    [b(1), b(3), b(2)],
    [b(4), b(3)],
    [b(5), b(4), b(0)],
  ]
  copy_pairs = [(b(1), b(0)), (b(3), b(1))]
  check_assign(schedule, copies=copy_pairs)
|
||||
|
||||
def test_deferred_copy_frees_chain(self):
  # Six rounds of "copy from pinned b(0) into a fresh buffer, then consume it in a compute step",
  # followed by one step on an unrelated pinned buffer. check_assign validates the resulting assignment.
  schedule, copy_pairs = [], []
  for step in range(6):
    copy_buf, compute_buf = b(step * 2 + 1), b(step * 2 + 2)
    # NOTE: b(0, pin=True) is deliberately called twice per round, as each pinned call takes a reference
    schedule += [[copy_buf, b(0, pin=True)], [compute_buf, copy_buf]]
    copy_pairs.append((copy_buf, b(0, pin=True)))
  schedule.append([b(100, pin=True)])
  check_assign(schedule, copies=copy_pairs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -108,7 +108,6 @@ class TestMultiRamUsage(unittest.TestCase):
|
||||
def test_matmul_half(self): self._test_matmul_half(dev_count=2)
|
||||
def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_multi_layer_allreduce(self):
|
||||
N = 32
|
||||
devices_2 = ("NULL:1", "NULL:2")
|
||||
|
||||
@@ -114,9 +114,9 @@ class GraphRunner(Runner):
|
||||
assert ji.prg.p.local_size is not None
|
||||
self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))
|
||||
|
||||
# used in MultiGraphRunner. the ints are id() of _bufs
|
||||
self.w_dependency_map: dict[int, Any] = {}
|
||||
self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)
|
||||
# used in MultiGraphRunner. tracks (offset, end, dep) ranges per base buffer id to handle suballocated buffers correctly.
|
||||
self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
|
||||
self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
|
||||
|
||||
assert jit_cache[0].prg is not None
|
||||
super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())
|
||||
@@ -132,19 +132,22 @@ class GraphRunner(Runner):
|
||||
yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
|
||||
|
||||
def _access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any):
  """
  Register an access to `bufs` by `new_dependency` and return what it has to wait for.

  Dependencies are tracked per base buffer (keyed by id(buf.base._buf)) as half-open byte
  ranges (start, end, dep), so two sub-buffers of one base only depend on each other
  when their ranges actually overlap.

  bufs:           buffers touched by the new node.
  write:          indices into `bufs` that are written; all other indices are reads.
  new_dependency: object recorded as the new reader/writer of the touched ranges.

  Returns the dependencies to wait on, deduplicated by identity in first-seen order.
  """
  wait_nodes = []

  # Pass 1: collect prerequisites before any map is modified, so every buffer of this access
  # is checked against the state left by earlier accesses only.
  for i,buf in enumerate(bufs):
    key, s, e = id(buf.base._buf), buf.offset, buf.offset + buf.nbytes
    # any access waits for previous writers of an overlapping range
    wait_nodes += [dep for st,en,dep in self.w_dependency_map[key] if st < e and s < en]
    # a write additionally waits for previous readers of an overlapping range
    if i in write: wait_nodes += [dep for st,en,dep in self.r_dependency_map[key] if st < e and s < en]

  # Pass 2: record this access.
  for i,buf in enumerate(bufs):
    key, s, e = id(buf.base._buf), buf.offset, buf.offset + buf.nbytes
    if i in write:
      # A write supersedes earlier readers and writers of [s, e): cut that range out of every
      # tracked entry, keeping the parts that lie to the left and to the right of it.
      for dmap in [self.w_dependency_map, self.r_dependency_map]:
        kept = []
        for st,en,dep in dmap[key]:
          if st < min(s, en): kept.append((st, min(s, en), dep))
          if max(e, st) < en: kept.append((max(e, st), en, dep))
        dmap[key] = kept
      self.w_dependency_map[key].append((s, e, new_dependency))
    else: self.r_dependency_map[key].append((s, e, new_dependency))

  # dict keyed by id() drops duplicates while preserving insertion order
  return list({id(x):x for x in wait_nodes}.values())
|
||||
|
||||
@staticmethod
|
||||
@@ -357,9 +360,8 @@ class TinyJit(Generic[ReturnType]):
|
||||
jit_cache = pruned
|
||||
|
||||
# memory planning (optional)
|
||||
# Exclude buffers involved in transfer ops to preserve parallelism.
|
||||
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec)) for b in ji.bufs}
|
||||
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
|
||||
copies = [(cast(Buffer,ji.bufs[0]),cast(Buffer,ji.bufs[1])) for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec))]
|
||||
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], copies, debug_prefix="JIT ")
|
||||
jit_cache = [replace(item, bufs=[assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
|
||||
|
||||
input_replace = get_input_replace(jit_cache, input_buffers)
|
||||
|
||||
@@ -7,43 +7,50 @@ from tinygrad.uop.ops import Ops
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.runtime.support.memory import TLSFAllocator
|
||||
|
||||
LaneKey = tuple[str, int]
|
||||
|
||||
# **************** memory planning ****************
|
||||
|
||||
def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
|
||||
def _internal_memory_planner(buffers:list[list[Buffer]], copies:list[tuple[Buffer, Buffer]]|None=None,
|
||||
ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
|
||||
if NO_MEMORY_PLANNER: return {}
|
||||
first_appearance, last_appearance, buf_to_opt = {}, {}, set()
|
||||
for i,u in enumerate(buffers):
|
||||
for buf in u:
|
||||
should_skip = buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers)
|
||||
if not ignore_checks and should_skip: continue
|
||||
if not ignore_checks and (buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0): continue
|
||||
if buf.base not in first_appearance: first_appearance[buf.base] = i
|
||||
last_appearance[buf.base] = i
|
||||
buf_to_opt.add(buf)
|
||||
|
||||
# Separate copy and compute buffers into different lanes and defer cross-queue frees to avoid introducing dependencies (copy->compute->copy)
|
||||
copy_dsts, copy_srcs = ({dst.base for dst,_ in copies}, {src.base for _,src in copies}) if copies else (set(), set())
|
||||
def _key(buf) -> LaneKey: return (buf.device, 1 if buf in copy_dsts or buf in copy_srcs else 0)
|
||||
buf_hold = {buf: last_appearance[buf] - first_appearance[buf] + 1 for buf in first_appearance if buf in copy_dsts or buf in copy_srcs}
|
||||
|
||||
# Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
|
||||
buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
|
||||
[((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
|
||||
total_memory = sum(round_up(buf.nbytes, min_block_size:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%)
|
||||
[((last_appearance[buf] + 1 + buf_hold.get(buf, 0), False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
|
||||
total_memory = sum(round_up(buf.nbytes, BLK:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%)
|
||||
|
||||
# Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
|
||||
# Also track buffer replacements for buffers that do not support suballocation.
|
||||
buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
|
||||
reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
|
||||
global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=min_block_size, lv2_cnt=32)))
|
||||
global_planner:dict[LaneKey, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=BLK, lv2_cnt=32)))
|
||||
for (_, is_open_ev), buf in buffer_requests:
|
||||
# Check if suballocation is possible for the given buffer and device.
|
||||
if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType):
|
||||
if is_open_ev: buffer_replace[buf] = (None, global_planner[buf.device][1].alloc(round_up(buf.nbytes, 0x1000)))
|
||||
else: global_planner[buf.device][1].free(cast(int, buffer_replace[buf][1]))
|
||||
global_planner[buf.device] = (max(global_planner[buf.device][0], buffer_replace[buf][1] + buf.nbytes), global_planner[buf.device][1])
|
||||
if is_open_ev: buffer_replace[buf] = (None, global_planner[_key(buf)][1].alloc(round_up(buf.nbytes, BLK)))
|
||||
else: global_planner[_key(buf)][1].free(cast(int, buffer_replace[buf][1]))
|
||||
global_planner[_key(buf)] = (max(global_planner[_key(buf)][0], buffer_replace[buf][1] + buf.nbytes), global_planner[_key(buf)][1])
|
||||
else:
|
||||
key = (buf.device, buf.dtype, buf.options, buf.nbytes)
|
||||
key = (_key(buf), buf.dtype, buf.options, buf.nbytes)
|
||||
if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None)
|
||||
else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0]))
|
||||
|
||||
# Allocate global buffers based on the memory planner.
|
||||
global_buffers = {dev: Buffer(dev, round_up(sz, 0x1000), dtypes.int8) for dev, (sz, _) in global_planner.items()}
|
||||
buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[buf.device], off) for buf,(base,off) in buffer_replace.items()}
|
||||
global_buffers = {key: Buffer(key[0], round_up(sz, BLK), dtypes.int8) for key, (sz, _) in global_planner.items()}
|
||||
buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[_key(buf)], off) for buf,(base,off) in buffer_replace.items()}
|
||||
|
||||
# Assign buffers. First, assign full buffers (not sub-buffers).
|
||||
assigned:dict[Buffer, Buffer] = {}
|
||||
@@ -66,5 +73,5 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
|
||||
def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
  """
  Run the memory planner over a schedule and return a new schedule with planned buffers.

  Each ExecItem is rebuilt with the same ast, metadata and fixedvars; every buffer that the
  planner reassigned is replaced, all other buffers (and None entries) are passed through.
  """
  # Buffers of COPY items are handed to the planner as (dst, src) pairs (bufs[0], bufs[1]),
  # not excluded from planning.
  copies = [(cast(Buffer,si.bufs[0]),cast(Buffer,si.bufs[1])) for si in schedule if si.ast.op is Ops.COPY]
  assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule], copies=copies)
  return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]
|
||||
|
||||
Reference in New Issue
Block a user