diff --git a/test/null/test_memory_planner.py b/test/null/test_memory_planner.py index 32d3250820..b60a9f0f69 100644 --- a/test/null/test_memory_planner.py +++ b/test/null/test_memory_planner.py @@ -11,8 +11,8 @@ def b(i, base=None, offset=0, pin=False, size=16): if pin: global_map[i].ref(1) return global_map[i] -def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]]): - assigned = _internal_memory_planner(buffers, noopt_buffers=None) +def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]], copies:list[tuple[Buffer, Buffer]]|None=None): + assigned = _internal_memory_planner(buffers, copies=copies) taken_parts = set() first_appearance, last_appearance = {}, {} @@ -134,5 +134,75 @@ class TestMemoryPlanner(unittest.TestCase): ] check_assign(bs) + def test_copy_bufs_separate_from_compute(self): + bs = [ + [b(0), b(1)], + [b(1), b(2)], + [b(3), b(2)], + ] + assigned = _internal_memory_planner(bs, copies=[(b(1), b(0))]) + r1, r2 = assigned.get(b(1), b(1)), assigned.get(b(2), b(2)) + assert r1.base != r2.base + + def test_copy_bufs_reuse_among_copies(self): + bs = [ + [b(0), b(1)], + [b(2), b(1)], + [b(3), b(2)], + ] + assigned = _internal_memory_planner(bs, copies=[(b(1), b(0)), (b(2), b(1))]) + r1, r2 = assigned.get(b(1), b(1)), assigned.get(b(2), b(2)) + assert r1.base == r2.base + + def test_compute_bufs_reuse_among_compute(self): + bs = [ + [b(0), b(1)], + [b(2), b(1)], + [b(3), b(2)], + [b(4), b(3)], + ] + assigned = _internal_memory_planner(bs, copies=[(b(1), b(0))]) + r2, r3 = assigned.get(b(2), b(2)), assigned.get(b(3), b(3)) + assert r2.base == r3.base + + def test_copy_and_compute_no_cross_reuse(self): + bs = [ + [b(0), b(1)], + [b(2), b(1)], + [b(3), b(2)], + ] + assigned = _internal_memory_planner(bs, copies=[(b(2), b(1))]) + r0, r2 = assigned.get(b(0), b(0)), assigned.get(b(2), b(2)) + assert r0.base != r2.base + + def test_multiple_copy_bufs_with_offsets(self): + bs = [ + [b(0, pin=True), b(1), b(2)], + [b(3, base=0, offset=1, size=8), b(1), b(2)], + [b(4), b(3)], + [b(5), b(4)], + ] + check_assign(bs, copies=[(b(1), b(0)), (b(2), b(0))]) + + def test_copy_bufs_pinned_mixed(self): + bs = [ + [b(0, pin=True), b(1), b(2)], + [b(1), b(3), b(2)], + [b(4), b(3)], + [b(5), b(4), b(0)], + ] + check_assign(bs, copies=[(b(1), b(0)), (b(3), b(1))]) + + def test_deferred_copy_frees_chain(self): + bs = [] + copies = [] + for i in range(6): + copy_buf, compute_buf = b(i * 2 + 1), b(i * 2 + 2) + bs.append([copy_buf, b(0, pin=True)]) + bs.append([compute_buf, copy_buf]) + copies.append((copy_buf, b(0, pin=True))) + bs.append([b(100, pin=True)]) + check_assign(bs, copies=copies) + if __name__ == "__main__": unittest.main() diff --git a/test/null/test_multitensor.py b/test/null/test_multitensor.py index f2f746dae3..2b748386ce 100644 --- a/test/null/test_multitensor.py +++ b/test/null/test_multitensor.py @@ -108,7 +108,6 @@ class TestMultiRamUsage(unittest.TestCase): def test_matmul_half(self): self._test_matmul_half(dev_count=2) def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4) - @unittest.expectedFailure def test_multi_layer_allreduce(self): N = 32 devices_2 = ("NULL:1", "NULL:2") diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 5213a5f7fb..c7d307290e 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -114,9 +114,9 @@ class GraphRunner(Runner): assert ji.prg.p.local_size is not None self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size)) - # used in MultiGraphRunner. the ints are id() of _bufs - self.w_dependency_map: dict[int, Any] = {} - self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list) + # used in MultiGraphRunner. tracks (offset, end, dep) ranges per base buffer id to handle suballocated buffers correctly. + self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list) + self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list) assert jit_cache[0].prg is not None super().__init__(colored(f"", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify()) @@ -132,19 +132,22 @@ class GraphRunner(Runner): yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1]) def _access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any): - # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource, - # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers. wait_nodes = [] - for i,buf in enumerate(bufs): - if id(buf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(buf.base._buf)]) + key, s, e = id(buf.base._buf), buf.offset, buf.offset + buf.nbytes + wait_nodes += [dep for st,en,dep in self.w_dependency_map[key] if st < e and s < en] + if i in write: wait_nodes += [dep for st,en,dep in self.r_dependency_map[key] if st < e and s < en] + for i,buf in enumerate(bufs): + key, s, e = id(buf.base._buf), buf.offset, buf.offset + buf.nbytes if i in write: - if id(buf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(buf.base._buf))) - - for i,buf in enumerate(bufs): - if i in write: self.w_dependency_map[id(buf.base._buf)] = new_dependency - else: self.r_dependency_map[id(buf.base._buf)].append(new_dependency) - + for dmap in [self.w_dependency_map, self.r_dependency_map]: + kept = [] + for st,en,dep in dmap[key]: + if st < min(s, en): kept.append((st, min(s, en), dep)) + if max(e, st) < en: kept.append((max(e, st), en, dep)) + dmap[key] = kept + self.w_dependency_map[key].append((s, e, new_dependency)) + else: self.r_dependency_map[key].append((s, e, new_dependency)) return list({id(x):x for x in wait_nodes}.values()) @staticmethod @@ -357,9 +360,8 @@ class TinyJit(Generic[ReturnType]): jit_cache = pruned # memory planning (optional) - # Exclude buffers involved in transfer ops to preserve parallelism. - noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec)) for b in ji.bufs} - assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ") + copies = [(cast(Buffer,ji.bufs[0]),cast(Buffer,ji.bufs[1])) for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec))] + assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], copies, debug_prefix="JIT ") jit_cache = [replace(item, bufs=[assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache] input_replace = get_input_replace(jit_cache, input_buffers) diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index ae1d42544a..7da8585ea3 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -7,43 +7,50 @@ from tinygrad.uop.ops import Ops from tinygrad.dtype import dtypes, ImageDType from tinygrad.runtime.support.memory import TLSFAllocator +LaneKey = tuple[str, int] + # **************** memory planning **************** -def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]: +def _internal_memory_planner(buffers:list[list[Buffer]], copies:list[tuple[Buffer, Buffer]]|None=None, + ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]: if NO_MEMORY_PLANNER: return {} first_appearance, last_appearance, buf_to_opt = {}, {}, set() for i,u in enumerate(buffers): for buf in u: - should_skip = buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers) - if not ignore_checks and should_skip: continue + if not ignore_checks and (buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0): continue if buf.base not in first_appearance: first_appearance[buf.base] = i last_appearance[buf.base] = i buf_to_opt.add(buf) + # Separate copy and compute buffers into different lanes and defer cross-queue frees to avoid introducing dependencies (copy->compute->copy) + copy_dsts, copy_srcs = ({dst.base for dst,_ in copies}, {src.base for _,src in copies}) if copies else (set(), set()) + def _key(buf) -> LaneKey: return (buf.device, 1 if buf in copy_dsts or buf in copy_srcs else 0) + buf_hold = {buf: last_appearance[buf] - first_appearance[buf] + 1 for buf in first_appearance if buf in copy_dsts or buf in copy_srcs} + # Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed. buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \ - [((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0]) - total_memory = sum(round_up(buf.nbytes, min_block_size:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%) + [((last_appearance[buf] + 1 + buf_hold.get(buf, 0), False), buf) for buf in first_appearance.keys()], key=lambda x: x[0]) + total_memory = sum(round_up(buf.nbytes, BLK:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%) # Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator. # Also track buffer replacements for buffers that do not support suballocation. buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {} reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list) - global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=min_block_size, lv2_cnt=32))) + global_planner:dict[LaneKey, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=BLK, lv2_cnt=32))) for (_, is_open_ev), buf in buffer_requests: # Check if suballocation is possible for the given buffer and device. if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType): - if is_open_ev: buffer_replace[buf] = (None, global_planner[buf.device][1].alloc(round_up(buf.nbytes, 0x1000))) - else: global_planner[buf.device][1].free(cast(int, buffer_replace[buf][1])) - global_planner[buf.device] = (max(global_planner[buf.device][0], buffer_replace[buf][1] + buf.nbytes), global_planner[buf.device][1]) + if is_open_ev: buffer_replace[buf] = (None, global_planner[_key(buf)][1].alloc(round_up(buf.nbytes, BLK))) + else: global_planner[_key(buf)][1].free(cast(int, buffer_replace[buf][1])) + global_planner[_key(buf)] = (max(global_planner[_key(buf)][0], buffer_replace[buf][1] + buf.nbytes), global_planner[_key(buf)][1]) else: - key = (buf.device, buf.dtype, buf.options, buf.nbytes) + key = (_key(buf), buf.dtype, buf.options, buf.nbytes) if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None) else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0])) # Allocate global buffers based on the memory planner. - global_buffers = {dev: Buffer(dev, round_up(sz, 0x1000), dtypes.int8) for dev, (sz, _) in global_planner.items()} - buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[buf.device], off) for buf,(base,off) in buffer_replace.items()} + global_buffers = {key: Buffer(key[0], round_up(sz, BLK), dtypes.int8) for key, (sz, _) in global_planner.items()} + buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[_key(buf)], off) for buf,(base,off) in buffer_replace.items()} # Assign buffers. First, assign full buffers (not sub-buffers). assigned:dict[Buffer, Buffer] = {} @@ -66,5 +73,5 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]: # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule], - noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None}) + copies=[(cast(Buffer,si.bufs[0]),cast(Buffer,si.bufs[1])) for si in schedule if si.ast.op is Ops.COPY]) return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]